internode.cu 106 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
#include "buffer.cuh"
lijian6's avatar
lijian6 committed
3
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
4
5
#include "launch.cuh"
#include "utils.cuh"
6
#include "shmem_wrapper.cuh"
lijian6's avatar
lijian6 committed
7
8
9
10

#ifndef DISABLE_ROCSHMEM

// TODO: fix unroll warnings
lijian6's avatar
lijian6 committed
11
12
13
14
15
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
Chenggang Zhao's avatar
Chenggang Zhao committed
16
17
18
19
20

namespace deep_ep {

namespace internode {

21
extern shmem_team_t cpu_rdma_team;
Chenggang Zhao's avatar
Chenggang Zhao committed
22
23
24
25
26
27
28
29
30

struct SourceMeta {
    // `src_rdma_rank`: the RDMA rank passed to the encoding constructor.
    // `is_token_in_nvl_rank_bits`: bit i is set iff `is_token_in_nvl_ranks[i]` was true
    // when the meta was built (one bit per NVL peer, NUM_MAX_NVL_PEERS == 8 bits used).
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // Pack `NUM_MAX_NVL_PEERS` booleans into a bitmask (bit i <- is_token_in_nvl_ranks[i]).
    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        src_rdma_rank             = rdma_rank;
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
#pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    // Returns whether bit `nvl_rank` of the packed mask is set.
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};

// Host-visible size, in bytes, of one `SourceMeta` record.
int get_source_meta_bytes() {
    const int num_bytes = static_cast<int>(sizeof(SourceMeta));
    return num_bytes;
}

lijian6's avatar
lijian6 committed
48
49
50
51
52
// Size in bytes of one token record in the RDMA buffer.
// The record holds:
//   - the hidden payload (`hidden_int4` int4 values),
//   - one `SourceMeta`,
//   - `num_scales` floats,
//   - `num_topk_idx` ints,
//   - `num_topk_weights` floats.
// The sum is passed through ALIGN(..., sizeof(int4)); presumably that rounds up to a multiple
// of 16 bytes (TODO confirm against ALIGN's definition).
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    return static_cast<int>(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
                                  num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
                                  num_topk_weights * sizeof(float), sizeof(int4)));
}

lijian6's avatar
lijian6 committed
57
58
59
// Returns the `int32_t` offset and the `int32_t` count of the RDMA buffer region to zero.
//  - offset: byte size of the RDMA token-data region divided by sizeof(int). That size is
//    bytes_per_token * recv_tokens * num_rdma_ranks * 2 * num_sms. The factor 2 presumably
//    covers the send/recv halves of a symmetric buffer (TODO confirm).
//  - count:  (NUM_MAX_NVL_PEERS * 2 + 4) ints per RDMA rank, times 2, times num_sms.
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    // Accumulate the byte count in 64 bits. A product of plain `int`s overflows (undefined
    // behaviour) for large buffers. The wrapped value could then pass the caller's
    // buffer-size assertion even though the region does not fit.
    const int64_t num_data_bytes =
        static_cast<int64_t>(get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx,
                                                          num_topk_weights)) *
        num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms;
    return {static_cast<int>(num_data_bytes / static_cast<int64_t>(sizeof(int))),
            (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
}
lijian6's avatar
lijian6 committed
65
66
67
68
// Returns the `int32_t` offset and the `int32_t` count of the NVL buffer region to zero.
//  - offset: per-token payload × recv tokens × NVL ranks × SMs, divided by sizeof(int).
//    The per-token payload is hidden bytes + scales + topk idx + topk weights + SourceMeta.
//    The product is formed in `size_t` because the `sizeof` terms are unsigned.
//  - count:  num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms ints.
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    // Return `int32_t` offset and count to clean
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
                     "Invalid size of `SourceMeta`");
    const auto num_bytes_per_nvl_token =
        hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
        num_topk_weights * sizeof(float) + sizeof(SourceMeta);
    const auto num_data_bytes =
        num_nvl_recv_buffer_tokens * num_bytes_per_nvl_token * num_nvl_ranks * num_sms;
    return {
        static_cast<int>(num_data_bytes / sizeof(int)),
        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
    };
}

// Maps a destination RDMA rank to the PE index passed to shmem put calls.
// In low-latency mode the index is the global rank dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank,
// i.e. the peer with the same NVL index on that RDMA rank. Otherwise the RDMA rank is used
// unchanged.
template <bool kLowLatencyMode>
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}

// Inter-node barrier issued by the calling thread.
// In low-latency mode it is a team barrier on `rdma_team`. Otherwise it is
// `shmem_device_barrier_all()` across all PEs.
template <bool kLowLatencyMode>
__forceinline__ __device__ void
dushmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm
    if (kLowLatencyMode)
        shmem_barrier(rdma_team);
    else
        shmem_device_barrier_all();
}

// Notify kernel that precedes the internode dispatch.
// The host launcher starts 1 + kNumRDMARanks blocks.
//
// Block 0 (communication):
//   1. Barrier across RDMA peers (thread kWarpSize) and across NVL peers (barrier_block).
//   2. Zero the RDMA and NVL "clean" ranges [offset, offset + count) that the dispatch uses.
//   3. Pack, per destination RDMA rank, a row laid out as follows. The row is sent to every
//      RDMA rank (shmem put for remote ranks, local warp copy for our own rank).
//        [0, NUM_MAX_NVL_PEERS)                 tokens per NVL rank of that RDMA rank
//        [NUM_MAX_NVL_PEERS, +num_rdma_experts) tokens per expert hosted by that RDMA rank
//        [NUM_MAX_NVL_PEERS + num_rdma_experts]  tokens for that RDMA rank as a whole
//   4. Reduce the received rows:
//        - write recv_rdma_rank_prefix_sum;
//        - forward per-rank and per-expert counts to NVL peers through their buffers;
//        - reduce what NVL peers sent into recv_gbl_rank_prefix_sum and the
//          (expert_alignment-rounded) per-expert counts.
//   5. Each `*_mapped` counter is written only after it reads back -1. Presumably the host
//      resets them to -1 and polls for the result (TODO confirm).
//
// Blocks 1..kNumRDMARanks (metadata), block b handles dst_rdma_rank = b - 1:
//   - Each warp takes channels and counts tokens routed to each NVL rank of that RDMA rank
//     (from `is_token_in_rank`).
//   - It writes the per-channel counts and then turns each matrix row into an inclusive
//     prefix sum: rdma_channel_prefix_matrix and gbl_channel_prefix_matrix.
template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
                int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
                const int nvl_clean_offset, const int nvl_num_int_clean,
                int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
                int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
                const shmem_team_t rdma_team) {
    auto sm_id     = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
         lane_id     = get_lane_id();
    auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;

    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_experts = num_experts / kNumRDMARanks,
         num_nvl_experts  = num_rdma_experts / NUM_MAX_NVL_PEERS;

    if (sm_id == 0) {
        // Communication with others
        // Global barrier: thread kWarpSize (first lane of the second warp) does the internode
        // sync, then the whole block takes part in the intra-node barrier_block
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);

        // Send numbers of tokens per rank/expert to RDMA ranks
        auto rdma_buffer_ptr_int        = reinterpret_cast<int *>(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed = SymBuffer<int>(
            rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);

        // Clean up for later data dispatch; the count-exchange buffer above must lie entirely
        // before the clean range so zeroing cannot clobber it
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <=
                         rdma_clean_offset * sizeof(int));
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer (one row per destination RDMA rank, layout documented above)
        for (int i = thread_id; i < num_ranks; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] =
                num_tokens_per_rank[i];
        for (int i = thread_id; i < num_experts; i += num_threads)
            rdma_recv_num_tokens_mixed.send_buffer(
                i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =
                num_tokens_per_expert[i];
        if (thread_id < kNumRDMARanks)
            rdma_recv_num_tokens_mixed.send_buffer(
                thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
                num_tokens_per_rdma_rank[thread_id];
        __syncthreads();

        // Issue send: row i goes into slot `rdma_rank` of RDMA rank i's receive buffer
        // TODO: more light fence or barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
        for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {
            if (i != rdma_rank) {
                shmemx_int_put_nbi_warp(
                    rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                    rdma_recv_num_tokens_mixed.send_buffer(i),
                    NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                    translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank));
            } else {
                // Our own row: plain local copy instead of a put to ourselves
                UNROLLED_WARP_COPY(1,
                                   lane_id,
                                   NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                   rdma_recv_num_tokens_mixed.send_buffer(i),
                                   ld_volatile_global,
                                   st_na_global);
            }
        }
        __syncthreads();

        // NOTE(review): no explicit quiet is issued for the non-blocking puts above before this
        // barrier. This relies on the barrier also completing puts issued by other warps.
        // Confirm against rocSHMEM semantics.
        if (thread_id == 0)
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        __syncthreads();

        // NVL buffers: threads < NUM_MAX_NVL_PEERS each address the buffer of NVL peer
        // `thread_id`; every thread also sees its own rank's buffer as the receive side
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
        auto nvl_reduced_num_tokens_per_expert =
            Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank =
            AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert =
            AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank =
            AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert =
            AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

        // Clean up for later data dispatch (same non-overlap requirement as the RDMA side)
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
                             nvl_send_num_tokens_per_rank.total_bytes +
                             nvl_send_num_tokens_per_expert.total_bytes <=
                         nvl_clean_offset * sizeof(int));
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer
        // TODO: may use DUSHMEM reduction
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens into an inclusive prefix sum over source RDMA ranks
        if (thread_id == 0) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i) {
                sum +=
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
                recv_rdma_rank_prefix_sum[i] = sum;
            }
            // Wait until the counter reads -1 before publishing the total
            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
                ;
            *moe_recv_rdma_counter_mapped = sum;
        }

        // Send numbers of tokens per rank/expert to NVL ranks (thread t writes into peer t)
        if (thread_id < NUM_MAX_NVL_PEERS) {
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] =
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            for (int i = 0; i < num_nvl_experts; ++i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
                    nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
        }
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);

        // Reduce number of tokens per rank/expert
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            int sum = 0;
            for (int i = 0; i < num_ranks; ++i) {
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
            while (ld_volatile_global(moe_recv_counter_mapped) != -1)
                ;
            *moe_recv_counter_mapped = sum;
        }
        if (thread_id < num_nvl_experts) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            // Round up to a multiple of `expert_alignment`
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
                ;
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }

        // Finally barrier
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
    } else {
        // Calculate meta data for destination RDMA rank `sm_id - 1`
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
                                   token_end_idx);

            // Iterate over tokens: read the NUM_MAX_NVL_PEERS routing flags of this RDMA rank as
            // one 64-bit word per token
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize) {
                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
                                 "Invalid number of NVL peers");
                auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>(
                    is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                auto is_token_in_rank_values =
                    reinterpret_cast<const bool *>(&is_token_in_rank_uint64);
#pragma unroll
                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)
                    per_nvl_rank_count[j] += is_token_in_rank_values[j];
                // A token counts once for the RDMA rank if any of its NVL ranks is selected
                total_count += (is_token_in_rank_uint64 != 0);
            }

            // Warp reduce
            total_count = warp_reduce_sum(total_count);
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);

            // Write into channel matrix
            if (lane_id == 0) {
#pragma unroll
                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) *
                                                  num_channels +
                                              channel_id] = per_nvl_rank_count[i];
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
            }
        }

        // Calculate prefix sum (in place, inclusive, along the channel dimension)
        __syncthreads();
        if (thread_id == 0) {
            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
            for (int i = 1; i < num_channels; ++i)
                prefix_row[i] += prefix_row[i - 1];
        }

        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
        if (thread_id < NUM_MAX_NVL_PEERS) {
            auto prefix_row = gbl_channel_prefix_matrix +
                              (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
            for (int i = 1; i < num_channels; ++i)
                prefix_row[i] += prefix_row[i - 1];
        }
    }
}

lijian6's avatar
lijian6 committed
324
325
326
327
328
329
330
331
332
333
334
// Host launcher for the `notify_dispatch` kernel.
//  - Computes the RDMA/NVL buffer ranges the kernel zeroes; `num_topk` is used for both the
//    top-k index and top-k weight counts.
//  - Checks that those ranges fit in `num_rdma_bytes` / `num_nvl_bytes`, and that both sizes
//    are below INT_MAX.
//  - Launches 1 + num_rdma_ranks blocks of kNumThreads threads on `stream`, where
//    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS.
//  - The kernel template is selected on `low_latency_mode` and, via SWITCH_RDMA_RANKS, on the
//    number of RDMA ranks.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                     bool low_latency_mode) {
    // One switch case: pick the template instance and launch it (ends with `break`)
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                \
    {                                                                                              \
        auto notify_dispatch_func = low_latency_mode ? notify_dispatch<true, num_rdma_ranks>       \
                                                     : notify_dispatch<false, num_rdma_ranks>;     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks,   \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert,         \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens,             \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second,         \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix,               \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,        \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team);               \
    }                                                                                              \
    break

    constexpr int kNumThreads    = 256;
    const auto    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta: {int32 offset, int32 count} of the regions the kernel zeroes
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());

    // Host-side checks of the kernel's block-size assumptions: per-RDMA-rank and per-NVL-peer
    // work is indexed directly by thread id, so each count must fit in one block
    EP_HOST_ASSERT(num_rdma_ranks <= kNumThreads);
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kNumThreads, "Assert NUM_MAX_NVL_PEERS <= kNumThreads");

    // Launch kernel: block 0 communicates, blocks 1..num_rdma_ranks compute channel metadata
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Upper bound on the number of RDMA ranks a token is sent to: `num_rdma_ranks` capped at 8.
// Used as the default value of the `kNumTopkRDMARanks` template parameter of `dispatch`.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    return num_rdma_ranks >= 8 ? 8 : num_rdma_ranks;
}

lijian6's avatar
lijian6 committed
380
381
382
383

template <bool kLowLatencyMode,
          int kNumRDMARanks,
          bool kCachedMode,
lijian6's avatar
lijian6 committed
384
385
          int kNumDispatchRDMASenderWarps,
          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
lijian6's avatar
lijian6 committed
386
387
388
389
390
391
392
393
394
395
396
397
398
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
        SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
        const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
        int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
        int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
        const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
        const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
        int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
        int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
        int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
        int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
        int num_ranks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
399
    enum class WarpRole {
lijian6's avatar
lijian6 committed
400
401
402
403
404
        kRDMASender,            // 从x写入到RDMA发送缓存
        kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
        kRDMAAndNVLForwarder,   // 从RDMA接收缓存转写到ipc nvl缓存
        kForwarderCoordinator,  // 向远端RDMA确认接收
        kNVLReceivers           // 从nvl缓存写入到recv_x
Chenggang Zhao's avatar
Chenggang Zhao committed
405
    };
lijian6's avatar
lijian6 committed
406
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
407
408
409
    __shared__ shmem_ctx_t ctx;
    shmem_wg_ctx_create(&ctx);
#endif
lijian6's avatar
lijian6 committed
410
411
412

    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
lijian6's avatar
lijian6 committed
413
414
415
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
416
417
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

lijian6's avatar
lijian6 committed
418
419
420
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);

Chenggang Zhao's avatar
Chenggang Zhao committed
421
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
lijian6's avatar
lijian6 committed
422
423
424
425
426
427
428
429
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASender, -1};
            } else if(warp_id == kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASenderCoordinator, -1};
            }
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            if(warp_id < NUM_MAX_NVL_PEERS) {
Chenggang Zhao's avatar
Chenggang Zhao committed
430
431
432
433
434
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else {
lijian6's avatar
lijian6 committed
435
            return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
Chenggang Zhao's avatar
Chenggang Zhao committed
436
437
        }
    }();
lijian6's avatar
lijian6 committed
438

lijian6's avatar
lijian6 committed
439
    auto warp_role = role_meta.first;
Chenggang Zhao's avatar
Chenggang Zhao committed
440
441
442
    auto target_rank = role_meta.second; // Not applicable for RDMA senders

    // RDMA symmetric layout
lijian6's avatar
lijian6 committed
443
444
445
446
447
448
    auto hidden_bytes             = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
449
450

    // NVL buffer layouts
lijian6's avatar
lijian6 committed
451
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
Chenggang Zhao's avatar
Chenggang Zhao committed
452
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
lijian6's avatar
lijian6 committed
453
    int rs_wr_rank = 0, ws_rr_rank = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
454
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
lijian6's avatar
lijian6 committed
455
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
456
    if (warp_role == WarpRole::kNVLReceivers)
lijian6's avatar
lijian6 committed
457
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
458
459

    // Allocate buffers
lijian6's avatar
lijian6 committed
460
461
462
463
464
465
466
467
468
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
469
470

    // RDMA sender warp synchronization
lijian6's avatar
lijian6 committed
471
472
473
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
474

lijian6's avatar
lijian6 committed
475
476
    // NVL and RDMA coordinate Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
477
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
lijian6's avatar
lijian6 committed
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498

    // Place the main logic of your kernel here, using the parameters above.
    if(warp_role == WarpRole::kRDMASender) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
        然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
        最后,它复制相关的数据到对称发送缓冲区。

        kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
        期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
        同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
        */
        // 获取任务范围
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // 清理共享内存
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
        if(warp_id == 0 && lane_id == 0) {
            rdma_send_next_token_idx = token_start_idx;
lijian6's avatar
lijian6 committed
499
        }
lijian6's avatar
lijian6 committed
500
501
502
        if(warp_id == 0 && lane_id < kNumRDMARanks) {
            rdma_send_channel_tail[lane_id]      = 0;
            rdma_send_channel_next_tail[lane_id] = 0;
lijian6's avatar
lijian6 committed
503
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
504

lijian6's avatar
lijian6 committed
505
506
507
508
509
510
        // 发送本通道中的令牌数量,通过 `-value - 1` 表示
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
        // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
        // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
        for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
511
            if (lane_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
512
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
513
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
514
                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
515
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
516
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
517
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
lijian6's avatar
lijian6 committed
518
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
519
            }
lijian6's avatar
lijian6 committed
520

lijian6's avatar
lijian6 committed
521
522
            syncwarp();
            if (dst_rdma_rank != rdma_rank) {
lijian6's avatar
lijian6 committed
523
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
524
525
526
527
528
                shmem_ctx_int_put_nbi_warp(ctx, 
#else
                shmemx_int_put_nbi_warp(
#endif
                rdma_channel_meta.recv_buffer(rdma_rank),
lijian6's avatar
lijian6 committed
529
530
                rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
531
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
532
        }
533

lijian6's avatar
lijian6 committed
534
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
535
536
537
538
539
        shmem_ctx_quiet(ctx);                
#else
        shmem_fence();
#endif

lijian6's avatar
lijian6 committed
540
541
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
542

lijian6's avatar
lijian6 committed
543
        // 遍历令牌并复制到缓冲区
Chenggang Zhao's avatar
Chenggang Zhao committed
544
        int64_t token_idx;
lijian6's avatar
lijian6 committed
545
546
547
548
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // 读取RDMA秩的存在性
Chenggang Zhao's avatar
Chenggang Zhao committed
549
            uint64_t is_token_in_rank_uint64 = 0;
lijian6's avatar
lijian6 committed
550
551
552
            if(lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
553

lijian6's avatar
lijian6 committed
554
555
556
557
            // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
            while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
                // 等待
            }
lijian6's avatar
lijian6 committed
558
            syncwarp();
559

lijian6's avatar
lijian6 committed
560
            // 获取下一个尾部位置
lijian6's avatar
lijian6 committed
561
            int rdma_tail_idx = -1;
lijian6's avatar
lijian6 committed
562
            if(is_token_in_rank_uint64 != 0) {
lijian6's avatar
lijian6 committed
563
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
lijian6's avatar
lijian6 committed
564
565
566
567
568

                // 与kForwarderCoordinator相互配合,调节发送数据的频率
                while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
569
            }
lijian6's avatar
lijian6 committed
570
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
571

lijian6's avatar
lijian6 committed
572
573
            // 存储RDMA头部以供合并
            if(lane_id < kNumRDMARanks && !kCachedMode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
574
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
lijian6's avatar
lijian6 committed
575
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
576

lijian6's avatar
lijian6 committed
577
578
579
580
            // 更新最后一个令牌尾部
            if(last_rdma_tail_idx >= 0) {
                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            }
lijian6's avatar
lijian6 committed
581
582
            last_rdma_tail_idx = rdma_tail_idx;

lijian6's avatar
lijian6 committed
583
584
585
586
            // 释放顺序锁
            if(lane_id == 0) {
                rdma_send_next_token_idx += 1;
            }
lijian6's avatar
lijian6 committed
587

lijian6's avatar
lijian6 committed
588
            // 广播尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
589
            SourceMeta src_meta;
lijian6's avatar
lijian6 committed
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            /*
            该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
            */
            #pragma unroll
            for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
                // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
                if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
                    // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
                    //  计算slot_idx在接收缓冲区中的位置
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;

                    // 存储当前RDMA秩到topk_ranks数组中
                    topk_ranks[num_topk_ranks] = i;

                    // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
lijian6's avatar
lijian6 committed
607
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
lijian6's avatar
lijian6 committed
608
609
610
611
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);

                    // 如果当前lane_id等于num_topk_ranks,则更新src_meta
                    if(lane_id == num_topk_ranks) {
lijian6's avatar
lijian6 committed
612
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
lijian6's avatar
lijian6 committed
613
614
615
616
617
                    }

                    // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
                    // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
lijian6's avatar
lijian6 committed
618
                }
lijian6's avatar
lijian6 committed
619
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
620
621
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

lijian6's avatar
lijian6 committed
622
623
624
625
626
627
            // 复制 `x` 到对称发送缓冲区
            auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
628
            };
lijian6's avatar
lijian6 committed
629
630
631
632
633
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
            }
lijian6's avatar
lijian6 committed
634

lijian6's avatar
lijian6 committed
635
636
637
638
639
640
641
642
            // 复制源元数据到对称发送缓冲区
            if(lane_id < num_topk_ranks) {
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
            }
lijian6's avatar
lijian6 committed
643

lijian6's avatar
lijian6 committed
644
645
646
            // 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_scales; i += kWarpSize) {
lijian6's avatar
lijian6 committed
647
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
lijian6's avatar
lijian6 committed
648
649
650
651
652
653
654
655
656
657
658

                // auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
                // auto value = ld_nc_global(x_scales + offset);
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
                }
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
lijian6's avatar
lijian6 committed
659
            }
660

lijian6's avatar
lijian6 committed
661
662
663
            // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
Chenggang Zhao's avatar
Chenggang Zhao committed
664
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
lijian6's avatar
lijian6 committed
665
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
666
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
lijian6's avatar
lijian6 committed
667
668
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
669
            }
lijian6's avatar
lijian6 committed
670
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
671

lijian6's avatar
lijian6 committed
672
673
674
675
676
677
        // 结尾部分
        // 获取顺序锁
        while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
            // 等待
        }

lijian6's avatar
lijian6 committed
678
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
679

lijian6's avatar
lijian6 committed
680
681
682
683
        // 更新最后一个令牌尾部
        if(last_rdma_tail_idx >= 0) {
            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
684

lijian6's avatar
lijian6 committed
685
686
687
688
689
690
691
692
693
694
695
696
        // 释放顺序锁
        if(lane_id == 0) {
            rdma_send_next_token_idx += 1;
        }
    } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
        如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
        最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。

        kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
lijian6's avatar
lijian6 committed
697
        dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
lijian6's avatar
lijian6 committed
698
699
700
701
702
703
        */
        if(warp_id > kNumDispatchRDMASenderWarps) {
            return;
        }
        // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
704

lijian6's avatar
lijian6 committed
705
706
707
        // 同步共享内存,确保所有线程在继续之前都达到了这一点
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
708

lijian6's avatar
lijian6 committed
709
        // 计算当前通道需要发送的令牌数
Chenggang Zhao's avatar
Chenggang Zhao committed
710
        int num_tokens_to_send = 0;
lijian6's avatar
lijian6 committed
711
        if(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
712
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
lijian6's avatar
lijian6 committed
713
714
            if(channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
715
716
        }

lijian6's avatar
lijian6 committed
717
        // 记录上次发出的尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
718
        int last_issued_tail = 0;
lijian6's avatar
lijian6 committed
719
720
721
722
723
724
725
        // 当有任何RDMA秩需要发送令牌时,继续循环
        while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
            for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // 计算目标RDMA秩
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;

                // 获取同步后的需要发送的令牌数
lijian6's avatar
lijian6 committed
726
                synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
727

lijian6's avatar
lijian6 committed
728
729
730
731
                if(synced_num_tokens_to_send == 0)
                    continue; // 如果没有令牌需要发送,则跳过

                // 读取进度
lijian6's avatar
lijian6 committed
732
                auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
lijian6's avatar
lijian6 committed
733
734
735
736
737
                auto processed_tail          = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed    = processed_tail - synced_last_issued_tail;

                // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
                if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
738
739
                    continue;

lijian6's avatar
lijian6 committed
740
741
742
743
744
745
                // 计算本次需要发出的令牌数
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);

                // 发出RDMA发送请求
                if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
746
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
747
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
748
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
749
750
751
752
                    shmem_ctx_schar_put_nbi_warp(ctx,
#else
                    shmemx_int8_put_nbi_warp(
#endif
lijian6's avatar
lijian6 committed
753
754
755
756
757
758
                        rdma_channel_data.recv_buffer(rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        rdma_channel_data.send_buffer(dst_rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        num_bytes_per_rdma_token * num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
759
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
760
761
762
763
                    shmem_ctx_quiet(ctx);                
#else
                    shmem_fence();
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
764
                } else {
lijian6's avatar
lijian6 committed
765
                    // 对于本地RDMA秩,使用较轻的内存屏障
Chenggang Zhao's avatar
Chenggang Zhao committed
766
767
768
                    memory_fence();
                }

lijian6's avatar
lijian6 committed
769
                // 更新尾部位置
lijian6's avatar
lijian6 committed
770
                syncwarp();
lijian6's avatar
lijian6 committed
771
                if(lane_id == dst_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
772
773
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
lijian6's avatar
lijian6 committed
774
                    // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
lijian6's avatar
lijian6 committed
775
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
776
777
778
779
780
                    shmem_ctx_ulong_atomic_add(ctx,
#else
                    shmem_signal_op_add(
#endif
                        rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
lijian6's avatar
lijian6 committed
781
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
782
783
                }
            }
lijian6's avatar
lijian6 committed
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
        } // while(__any(num_tokens_to_send > 0))
    } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
        它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
        接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
        然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
        最后,它同步头部和尾部索引,并标记通道为退役状态。
        */
        // RDMA消费者和NVL生产者
        const auto dst_nvl_rank          = target_rank;                                       // 目标NVL秩
        const auto dst_rank              = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;      // 目标秩
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);              // 目标秩专家开始
        const auto dst_rank_expert_end   = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束

        // 等待计数器到达
Chenggang Zhao's avatar
Chenggang Zhao committed
800
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
lijian6's avatar
lijian6 committed
801
802
        EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
803
804
805
806
807
808
809
810
811
        if(lane_id < kNumRDMARanks) {
            while(true) {
                // 对应于kRDMASender中的数据写入
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);                     // 是nvl节点的起始地址
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);            // 本rdma节点的起始地址
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);        // 本节点的结束地址
                if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
                    // 通知NVL秩
Chenggang Zhao's avatar
Chenggang Zhao committed
812
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
lijian6's avatar
lijian6 committed
813
814
815
                    EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);

                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
816
817
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

lijian6's avatar
lijian6 committed
818
819
                    // 保存从RDMA通道接收的令牌计数
                    src_rdma_channel_prefix = -meta_2 - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
820
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
lijian6's avatar
lijian6 committed
821
822
823
824
825
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
                    if(!kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;

                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
Chenggang Zhao's avatar
Chenggang Zhao committed
826
827
828
829
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

lijian6's avatar
lijian6 committed
830
831
832
833
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
Chenggang Zhao's avatar
Chenggang Zhao committed
834
835
836
837
                    trap();
                }
            }
        }
lijian6's avatar
lijian6 committed
838
        syncwarp();
lijian6's avatar
lijian6 committed
839
840

        // 移动缓存的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
841
842
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

lijian6's avatar
lijian6 committed
843
844
845
        // 等待共享内存被清理
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
846

lijian6's avatar
lijian6 committed
847
848
849
850
        // 开始准备处理接受数据,直到所有的数据接受完成。
        // 转发从RDMA缓冲区的令牌
        // 注意:总是从本地秩开始
        int src_rdma_rank = sm_id % kNumRDMARanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
851
852
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
lijian6's avatar
lijian6 committed
853
854
        while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
            // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
lijian6's avatar
lijian6 committed
855
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
856
857
858

            // 用于给kNVLReceivers进行互动,控制数据的传输速度
            while(lane_id == 0) {
lijian6's avatar
lijian6 committed
859
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
lijian6's avatar
lijian6 committed
860
                if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
861
                    break;
lijian6's avatar
lijian6 committed
862
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
863

lijian6's avatar
lijian6 committed
864
865
866
867
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
868
869
870
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
871
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
872

lijian6's avatar
lijian6 committed
873
            // 找到下一个源RDMA秩(轮询)
lijian6's avatar
lijian6 committed
874
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
875
            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
876
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
lijian6's avatar
lijian6 committed
877
878
879
880
881
                if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));

                    if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
Chenggang Zhao's avatar
Chenggang Zhao committed
882
                        break;
lijian6's avatar
lijian6 committed
883
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
884
885
                }

lijian6's avatar
lijian6 committed
886
887
888
889
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
Chenggang Zhao's avatar
Chenggang Zhao committed
890
891
892
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
893

lijian6's avatar
lijian6 committed
894
895
            auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
896

lijian6's avatar
lijian6 committed
897
898
899
900
901
902
903
904
905
906
            // 遍历RDMA缓冲区中的每一个令牌
            for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
                void* shifted           = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
                auto src_meta           = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
                if(lane_id == src_rdma_rank) {
                    num_tokens_to_recv_from_rdma -= 1;
                }

Chenggang Zhao's avatar
Chenggang Zhao committed
907
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
lijian6's avatar
lijian6 committed
908
                if(lane_id == src_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
909
910
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
lijian6's avatar
lijian6 committed
911
                    if(!kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
912
913
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
lijian6's avatar
lijian6 committed
914
915

                if(!is_in_dst_nvl_rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
916
917
                    continue;

lijian6's avatar
lijian6 committed
918
                // 获取一个空闲槽位
lijian6's avatar
lijian6 committed
919
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
920

lijian6's avatar
lijian6 committed
921
                // 复制数据
lijian6's avatar
lijian6 committed
922
923
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
lijian6's avatar
lijian6 committed
924
925
926
                                   reinterpret_cast<int4*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
lijian6's avatar
lijian6 committed
927

lijian6's avatar
lijian6 committed
928
929
                // 复制源元数据
                if(lane_id == 0)
lijian6's avatar
lijian6 committed
930
                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
lijian6's avatar
lijian6 committed
931
                shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
lijian6's avatar
lijian6 committed
932

lijian6's avatar
lijian6 committed
933
                // 复制 `x_scales`
lijian6's avatar
lijian6 committed
934
935
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
lijian6's avatar
lijian6 committed
936
937
938
939
940
941
942
943
944
945
946
947
948
949
                                   reinterpret_cast<float*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<float*>(shifted) + num_scales;

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
                    // 读取
                    auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
                    shifted = reinterpret_cast<int*>(shifted) + num_topk;
                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);

                    // 转换和写入
                    idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
lijian6's avatar
lijian6 committed
950
                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
lijian6's avatar
lijian6 committed
951
                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
952
953
                }

lijian6's avatar
lijian6 committed
954
955
                // 在NVL缓冲区不足的情况下,提前停止
                if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
956
957
958
                    src_rdma_tail = i + 1;
            }

lijian6's avatar
lijian6 committed
959
960
961
            // 同步头部索引
            if(lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
962

lijian6's avatar
lijian6 committed
963
            // 移动尾部索引,与kNVLReceivers互相通信使用
lijian6's avatar
lijian6 committed
964
            syncwarp();
lijian6's avatar
lijian6 committed
965
966
967
            if(lane_id == 0) {
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
968
969
970
        }

        // Retired
lijian6's avatar
lijian6 committed
971
        syncwarp();
lijian6's avatar
lijian6 committed
972
        if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
973
            forward_channel_retired[dst_nvl_rank] = true;
lijian6's avatar
lijian6 committed
974
975
976
977
978
979
980
981
982
        }
    } else if(warp_role == WarpRole::kForwarderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
        它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
        然后,它清理共享内存,并初始化转发通道的头部和退役状态。
        接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
        否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
        */
Chenggang Zhao's avatar
Chenggang Zhao committed
983
        // Extra warps for forwarder coordinator should exit directly
lijian6's avatar
lijian6 committed
984
        if (warp_id > NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
985
986
            return;

lijian6's avatar
lijian6 committed
987
988
989
990
991
992
        // 转发warp协调器
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        // 清理共享内存
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
        for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
Chenggang Zhao's avatar
Chenggang Zhao committed
993
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
lijian6's avatar
lijian6 committed
994
        if(lane_id < NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
995
            forward_channel_retired[lane_id] = false;
lijian6's avatar
lijian6 committed
996
997
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
998
999

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
lijian6's avatar
lijian6 committed
1000
1001
1002

        while(true) {
            // 找到最小的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
1003
            int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1004
#pragma unroll
lijian6's avatar
lijian6 committed
1005
1006
            for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if(!forward_channel_retired[i])
lijian6's avatar
lijian6 committed
1007
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
lijian6's avatar
lijian6 committed
1008
1009

            if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1010
                break;
lijian6's avatar
lijian6 committed
1011
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1012

lijian6's avatar
lijian6 committed
1013
1014
            // 更新远程头部
            if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
lijian6's avatar
lijian6 committed
1015
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1016
1017
1018
1019
1020
                shmem_ctx_ulong_atomic_add(ctx,
#else
                shmem_signal_op_add(
#endif
                    rdma_channel_head.buffer(rdma_rank), min_head - last_head,
lijian6's avatar
lijian6 committed
1021
                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
1022
1023
                last_head = min_head;
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1024

lijian6's avatar
lijian6 committed
1025
            // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1026
            __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1027
        }
lijian6's avatar
lijian6 committed
1028
1029
1030
1031
1032
1033
1034
1035
    } else if(warp_role == WarpRole::kNVLReceivers) {
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }

        // Place the main logic of your kernel here, using the parameters above.
        // NVL消费者
        // 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
Chenggang Zhao's avatar
Chenggang Zhao committed
1036
        int src_nvl_rank = target_rank, total_offset = 0;
lijian6's avatar
lijian6 committed
1037
1038
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1039
1040
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

lijian6's avatar
lijian6 committed
1041
1042
        // 接收通道偏移
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
1043
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1044
1045

        while(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1046
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1047
            end_offset   = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1048
            if(start_offset < 0 && end_offset < 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1049
1050
1051
1052
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
lijian6's avatar
lijian6 committed
1053
1054
1055
1056
            // 超时检查
            if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
1057
1058
1059
                trap();
            }
        }
lijian6's avatar
lijian6 committed
1060

Chenggang Zhao's avatar
Chenggang Zhao committed
1061
1062
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

lijian6's avatar
lijian6 committed
1063
1064
1065
        // 保存以供合并使用
        if(lane_id < kNumRDMARanks && !kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
lijian6's avatar
lijian6 committed
1066
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1067
1068

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1069
1070
        while(num_tokens_to_recv > 0) {
            // 通过通道0检查通道状态
lijian6's avatar
lijian6 committed
1071
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1072
1073
1074
            while(lane_id == 0) {
                // 准备复制
                if(cached_channel_head_idx != cached_channel_tail_idx)
Chenggang Zhao's avatar
Chenggang Zhao committed
1075
                    break;
lijian6's avatar
lijian6 committed
1076
1077
1078
1079
1080
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                            channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1081
1082
1083
1084
                    trap();
                }
            }

lijian6's avatar
lijian6 committed
1085
            // 同步队列尾部
lijian6's avatar
lijian6 committed
1086
1087
            cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);

lijian6's avatar
lijian6 committed
1088
            // 复制数据
Chenggang Zhao's avatar
Chenggang Zhao committed
1089
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1090
1091
1092
1093
            for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto meta               = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx  = shfl_sync(total_offset, meta.src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1094
1095
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

lijian6's avatar
lijian6 committed
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
                // 复制数据
                UNROLLED_WARP_COPY(5,
                                lane_id,
                                hidden_int4,
                                recv_x + recv_token_idx * hidden_int4,
                                nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                ld_nc_global,
                                st_na_global);

                // 复制源元数据
                if(lane_id == 0 && !kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
1107
                    st_na_global(recv_src_meta + recv_token_idx, meta);
lijian6's avatar
lijian6 committed
1108

lijian6's avatar
lijian6 committed
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
                // 复制比例
                UNROLLED_WARP_COPY(1,
                                lane_id,
                                num_scales,
                                recv_x_scales + recv_token_idx * num_scales,
                                nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                ld_nc_global,
                                st_na_global);

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
lijian6's avatar
lijian6 committed
1120
1121
                    auto recv_idx   = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
lijian6's avatar
lijian6 committed
1122
1123
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1124
1125
1126
                }
            }

lijian6's avatar
lijian6 committed
1127
            // 移动队列
lijian6's avatar
lijian6 committed
1128
            syncwarp();
lijian6's avatar
lijian6 committed
1129
            if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1130
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
lijian6's avatar
lijian6 committed
1131
1132
            }
        } // while(num_tokens_to_recv > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1133
    }
lijian6's avatar
lijian6 committed
1134
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1135
1136
    shmem_wg_ctx_destroy(&ctx);
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1137
1138
}

lijian6's avatar
lijian6 committed
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1152
    constexpr int kNumDispatchRDMASenderWarps = 7;
lijian6's avatar
lijian6 committed
1153
1154
    // Make sure never OOB
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
lijian6's avatar
lijian6 committed
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                       \
    {                                                                                              \
        auto dispatch_func =                                                                       \
            low_latency_mode                                                                       \
                ? (is_cached_dispatch                                                              \
                       ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>         \
                       : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>)       \
                : (is_cached_dispatch                                                              \
                       ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>        \
                       : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>);     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch_func, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx,   \
            recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta),                      \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head,   \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix,        \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,      \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales,       \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,       \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
Chenggang Zhao's avatar
Chenggang Zhao committed
1180
1181
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

lijian6's avatar
lijian6 committed
1182
1183
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
                        (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1184
1185
1186
1187
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lijian6's avatar
lijian6 committed
1188
template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
1189
__global__ void __launch_bounds__(1024, 1)
lijian6's avatar
lijian6 committed
1190
1191
1192
1193
1194
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
1195
              bool is_cached_dispatch, const shmem_team_t rdma_team) {
lijian6's avatar
lijian6 committed
1196
1197
    auto sm_id       = static_cast<int>(blockIdx.x);
    auto thread_id   = static_cast<int>(threadIdx.x);
Chenggang Zhao's avatar
Chenggang Zhao committed
1198
    auto num_threads = static_cast<int>(blockDim.x);
lijian6's avatar
lijian6 committed
1199
1200
1201
    auto num_warps   = num_threads / kWarpSize;
    auto warp_id     = thread_id / kWarpSize;
    auto lane_id     = get_lane_id();
Chenggang Zhao's avatar
Chenggang Zhao committed
1202

lijian6's avatar
lijian6 committed
1203
    auto nvl_rank       = rank % NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1204
1205
1206
1207
1208
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Using two SMs, which clean the RDMA/NVL buffer respectively
    if (sm_id == 0) {
        // Barrier for RDMA
lishen's avatar
lishen committed
1209
        if (thread_id == kWarpSize)
lijian6's avatar
lijian6 committed
1210
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Shangyan Zhou's avatar
Fix  
Shangyan Zhou committed
1211

lishen's avatar
lishen committed
1212
1213
        // Barrier for NVL
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1214

lishen's avatar
lishen committed
1215
        // Clean RDMA buffer
lijian6's avatar
lijian6 committed
1216
        auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
1217
1218
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1219

lishen's avatar
lishen committed
1220
        // Clean NVL buffer
lijian6's avatar
lijian6 committed
1221
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
Chenggang Zhao's avatar
Chenggang Zhao committed
1222
1223
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
lishen's avatar
lishen committed
1224

1225
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1226

lishen's avatar
lishen committed
1227
1228
1229
1230
        // Barrier again
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);

Chenggang Zhao's avatar
Chenggang Zhao committed
1231
        // Barrier again
1232
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
lishen's avatar
lishen committed
1233
    } else if (sm_id == 1) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1234
1235
1236
1237
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1238
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
1239
1240
1241
1242

        // Iterate in reverse order
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1243
1244
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1245
1246
1247

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1248
1249
1250
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                auto current_head =
                    __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1263
1264
1265
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
                                  rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
lishen's avatar
lishen committed
1266
        constexpr int num_clean_sms = 2;
1267

lijian6's avatar
lijian6 committed
1268
        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
lishen's avatar
lishen committed
1269
1270
            for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
                 dst_rdma_rank += num_channels * 2 - num_clean_sms) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1271
                // Iterate in reverse order
lijian6's avatar
lijian6 committed
1272
1273
1274
1275
1276
1277
                int token_start_idx =
                    warp_id == 0
                        ? 0
                        : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx =
                    rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
1278
1279
1280
1281
1282
                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1283
1284
1285
1286
1287
1288
1289
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                    auto current_head =
                        __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                    if (current_head < 0) {
                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
1290
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1291
1292
1293
1294
1295
1296
1297
                }
            }
        }
    }
}

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
1298
1299
1300
1301
1302
1303
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
1304
                   bool is_cached_dispatch, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1305
    const int  num_threads    = ::max(128, kWarpSize * num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1306
1307
1308
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
lijian6's avatar
lijian6 committed
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
1319
1320
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
lishen's avatar
lishen committed
1321
    EP_HOST_ASSERT(num_channels * 2 > 2);
Chenggang Zhao's avatar
Chenggang Zhao committed
1322
1323

    // Launch kernel
lijian6's avatar
lijian6 committed
1324
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
Chenggang Zhao's avatar
Chenggang Zhao committed
1325
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
lijian6's avatar
lijian6 committed
1326
1327
1328
1329
1330
1331
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1332
1333
}

lijian6's avatar
lijian6 committed
1334
1335
1336
1337
1338
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    // Reduce one token across all source ranks that hold it (warp-cooperative).
    //   - Lane `i` (< kNumRanks) supplies `is_token_in_rank` / `head_idx` for rank `i`.
    //   - `recv_fn(rank, slot, int4_idx)` reads one int4 of the received row.
    //   - `recv_tw_fn(rank, slot, k)` reads one received top-k weight.
    //   - The summed row goes to `combined_row` and the summed weights to `combined_topk_weights`.
    // Returns the smallest source rank holding the token, or -1 if no rank holds it.

    // Number of `dtype_t` elements in one `int4` (signed to keep loop comparisons int-vs-int)
    constexpr int kDtypePerInt4 = static_cast<int>(sizeof(int4) / sizeof(dtype_t));

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        // The count can only exceed the capacity when `kNumRanks > kMaxNumRanks`. Check *before*
        // writing, so a violation traps instead of overflowing the local arrays. The condition is
        // a compile-time constant, so the check disappears otherwise.
        if (kNumRanks > kMaxNumRanks) {
            EP_DEVICE_ASSERT(num_topk_ranks < kMaxNumRanks);
        }
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }

    // Reduce data
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

        // Reduce all-to-all results in fp32
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`
    if (lane_id < num_topk) {
        float value = 0.0f;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank; `topk_ranks[0]` is uninitialized when no rank holds the token
    return num_topk_ranks > 0 ? topk_ranks[0] : -1;
}

lijian6's avatar
lijian6 committed
1394
1395
1396
1397
template <bool kLowLatencyMode,
          int kNumRDMARanks,
          typename dtype_t,
          int kNumCombineForwarderWarps,
lijian6's avatar
lijian6 committed
1398
          int kNumTopkRDMARanks     = get_num_topk_rdma_ranks(kNumRDMARanks),
lijian6's avatar
lijian6 committed
1399
          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
lijian6's avatar
lijian6 committed
1400
          int kNumForwarders        = kNumRDMARanks * kNumWarpsPerForwarder,
lijian6's avatar
lijian6 committed
1401
1402
1403
          int kNumRDMAReceivers     = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) 
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
lijian6's avatar
lijian6 committed
1404
1405
1406
1407
1408
1409
1410
1411
            const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
            const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
            const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
            const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
            int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
            int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
            int num_ranks) {
lijian6's avatar
lijian6 committed
1412
1413
1414
1415
1416
1417
1418
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kRDMACoordinator,
        kNVLCoordinator
    };
lijian6's avatar
lijian6 committed
1419
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1420
1421
1422
    __shared__ shmem_ctx_t ctx;
    shmem_wg_ctx_create(&ctx);
#endif
lijian6's avatar
lijian6 committed
1423

lijian6's avatar
lijian6 committed
1424
1425
1426
1427
1428
1429
    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

Chenggang Zhao's avatar
Chenggang Zhao committed
1430
1431
1432
1433
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
1434
1435
1436
1437
1438
1439
1440

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumForwarders) {
                return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
Chenggang Zhao's avatar
Chenggang Zhao committed
1441
            } else {
lijian6's avatar
lijian6 committed
1442
                return {WarpRole::kRDMACoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1443
1444
            }
        } else {
lijian6's avatar
lijian6 committed
1445
            if(warp_id < kNumForwarders) {
lijian6's avatar
lijian6 committed
1446
                return {WarpRole::kRDMAReceiver, warp_id};
Chenggang Zhao's avatar
Chenggang Zhao committed
1447
            } else {
lijian6's avatar
lijian6 committed
1448
                return {WarpRole::kNVLCoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1449
1450
1451
            }
        }
    }();
lijian6's avatar
lijian6 committed
1452

Chenggang Zhao's avatar
Chenggang Zhao committed
1453
    auto warp_role = role_meta.first;
lijian6's avatar
lijian6 committed
1454
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
Chenggang Zhao's avatar
Chenggang Zhao committed
1455

lijian6's avatar
lijian6 committed
1456
    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
1457
1458
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1459
    // This approach is designed to sync multiple warps in a loop
lijian6's avatar
lijian6 committed
1460
1461
1462
1463
    constexpr int num_sync_large_iteration = 64;
    constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
    __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];   
    for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
lijian6's avatar
lijian6 committed
1464
1465
1466
1467
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
1468
    if (warp_role == WarpRole::kNVLSender) {
lijian6's avatar
lijian6 committed
1469
1470
1471
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
1472

lijian6's avatar
lijian6 committed
1473
        const auto dst_nvl_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
1474
1475
1476
        // NVL layouts
        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
1477
1478
1479
1480
1481
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
1482

Chenggang Zhao's avatar
Chenggang Zhao committed
1483
1484
        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
lijian6's avatar
lijian6 committed
1485
1486
        if(lane_id < kNumRDMARanks) {
            int prefix_idx  = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
Chenggang Zhao's avatar
Chenggang Zhao committed
1487
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
lijian6's avatar
lijian6 committed
1488
            token_end_idx   = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1489
        }
lijian6's avatar
lijian6 committed
1490
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1491
1492
1493

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1494
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1495
1496

        // Iterate over all tokens and send by chunks
lijian6's avatar
lijian6 committed
1497
        while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1498
            // Exit if possible
lijian6's avatar
lijian6 committed
1499
            if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1500
1501
                break;

lijian6's avatar
lijian6 committed
1502
            // Decide next RDMA buffer to send
Chenggang Zhao's avatar
Chenggang Zhao committed
1503
            bool is_lane_ready = false;
lijian6's avatar
lijian6 committed
1504
            auto start_time    = wall_clock64();
lijian6's avatar
lijian6 committed
1505
1506

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1507
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1508
                is_lane_ready      = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
lijian6's avatar
lijian6 committed
1509
1510
1511
                                num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;

                if(__any_sync(kFullWarpMask, is_lane_ready))
Chenggang Zhao's avatar
Chenggang Zhao committed
1512
                    break;
lijian6's avatar
lijian6 committed
1513

Chenggang Zhao's avatar
Chenggang Zhao committed
1514
                // Retry
lijian6's avatar
lijian6 committed
1515
1516
                if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1517
1518

                // Timeout check
lijian6's avatar
lijian6 committed
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
                if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
                        "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        ld_volatile_global(nvl_channel_head.buffer() + lane_id),
                        cached_channel_tail_idx,
                        token_start_idx,
                        token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1531
1532
1533
1534
1535
                    trap();
                }
            }

            // Sync token start index and count
lijian6's avatar
lijian6 committed
1536
1537
            for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
                if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1538
1539
1540
                    continue;

                // Sync token start index
lijian6's avatar
lijian6 committed
1541
1542
                auto token_idx          = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1543
1544

                // Send by chunk
lijian6's avatar
lijian6 committed
1545
                for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1546
1547
                    // Get an empty slot
                    int dst_slot_idx = 0;
lijian6's avatar
lijian6 committed
1548
1549
1550
                    if(lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
1551
                    }
lijian6's avatar
lijian6 committed
1552
                    dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
lijian6's avatar
lijian6 committed
1553
1554
1555
1556

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x         = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
1557
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
1558

lijian6's avatar
lijian6 committed
1559
                    // Copy source meta
lijian6's avatar
lijian6 committed
1560
1561
                    if(lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1562

lijian6's avatar
lijian6 committed
1563
                    // Copy `topk_weights`
lijian6's avatar
lijian6 committed
1564
1565
1566
                    if(lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
                                    ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1567
1568
1569
1570
1571
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
lijian6's avatar
lijian6 committed
1572
            syncwarp();
lijian6's avatar
lijian6 committed
1573
1574
1575
            if(lane_id < kNumRDMARanks and is_lane_ready) {
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1576
1577
        }
    } else {
lijian6's avatar
lijian6 committed
1578
1579
1580
1581
        if(warp_id > kNumForwarders) {
            return;
        }

Chenggang Zhao's avatar
Chenggang Zhao committed
1582
1583
        // Combiners and coordinators
        // RDMA symmetric layout
lijian6's avatar
lijian6 committed
1584
        auto hidden_bytes = hidden_int4 * sizeof(int4);
lijian6's avatar
lijian6 committed
1585
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
lijian6's avatar
lijian6 committed
1586
1587
1588
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1589
1590

        // NVL layouts
lijian6's avatar
lijian6 committed
1591
1592
1593
1594
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
Chenggang Zhao's avatar
Chenggang Zhao committed
1595
            nvl_buffers[i] = buffer_ptrs[i];
lijian6's avatar
lijian6 committed
1596
1597
1598
1599
1600
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
Chenggang Zhao's avatar
Chenggang Zhao committed
1601
1602

        // Combiner warp synchronization
lijian6's avatar
lijian6 committed
1603
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
1604
        __shared__ volatile bool forwarder_retired[kNumForwarders];
lijian6's avatar
lijian6 committed
1605
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
1606
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
lijian6's avatar
lijian6 committed
1607

Chenggang Zhao's avatar
Chenggang Zhao committed
1608
1609
1610
        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA ranks
lijian6's avatar
lijian6 committed
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
            const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
            const auto sub_warp_id   = target_rank % kNumWarpsPerForwarder;
            auto send_buffer         = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            // auto sync_large_warp     = [=]() {
            //     if(kNumWarpsPerForwarder == 1) {
            //         syncwarp();
            //     } else {
            //         // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
            //         // __syncthreads();
            //         syncwarp();
            //     }
            // };
            auto sync_large_warp = [=](const int iter, const int mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1624
                if (kNumWarpsPerForwarder == 1) {
lijian6's avatar
lijian6 committed
1625
                    syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1626
                } else {
lijian6's avatar
lijian6 committed
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
                        // LDS index to store for sync
                        int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        //reset index in the LDS to avoid race condition due to warp scheduling
                        int reset_idx =         dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        auto start_time = wall_clock64();
                        if (lane_id == 0){
                            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
                        }
                        syncwarp();
                        //The while(...) loop polls the counter until all warps have arrived
                        if (lane_id == 0){
                            while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
                                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                                    printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
                                    trap();
                                }
lijian6's avatar
lijian6 committed
1643
1644
                            }
                        }
lijian6's avatar
lijian6 committed
1645
1646
1647
1648
1649
                        syncwarp();
                        if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
                            sync_large_warp_counters[reset_idx] = 0;
                        }
                        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1650
1651
                }
            };
lijian6's avatar
lijian6 committed
1652
            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
1653

lijian6's avatar
lijian6 committed
1654
1655
            // Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
lijian6's avatar
lijian6 committed
1656
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
lijian6's avatar
lijian6 committed
1657
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
Chenggang Zhao's avatar
Chenggang Zhao committed
1658
1659
1660
1661
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1662
1663
1664
1665
1666
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
            // sync_forwarder_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1667
1668
1669

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1670
1671
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1672
1673
1674
1675
1676
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
lijian6's avatar
lijian6 committed
1677
            for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1678
                // Check destination queue emptiness, or wait a buffer to be released
lijian6's avatar
lijian6 committed
1679
                auto token_end_idx      = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
Chenggang Zhao's avatar
Chenggang Zhao committed
1680
                auto num_chunked_tokens = token_end_idx - token_start_idx;
lijian6's avatar
lijian6 committed
1681
                auto start_time         = wall_clock64();
lijian6's avatar
lijian6 committed
1682
1683
1684
1685
1686
1687
                while(sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));

                    if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
1688
1689
1690
                        break;

                    // Timeout check
lijian6's avatar
lijian6 committed
1691
1692
1693
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                                channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
1694
1695
1696
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1697
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1698
                sync_large_warp(token_start_idx, 0);
lijian6's avatar
lijian6 committed
1699

Chenggang Zhao's avatar
Chenggang Zhao committed
1700
                // Combine and write to the RDMA buffer
lijian6's avatar
lijian6 committed
1701
                for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1702
                    // Read expected head
lijian6's avatar
lijian6 committed
1703
                    EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1704
                    int expected_head = -1;
lijian6's avatar
lijian6 committed
1705
1706
                    if(lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1707
1708

                    // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1709
                    start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1710
1711
                    while(cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1712
1713

                        // Timeout check
lijian6's avatar
lijian6 committed
1714
1715
1716
                        if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                    channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1717
1718
1719
1720
1721
                            trap();
                        }
                    }

                    // Combine current token
lijian6's avatar
lijian6 committed
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                    expected_head, lane_id,
                                                                                    hidden_int4, num_topk,
                                                                                    reinterpret_cast<int4*>(shifted),
                                                                                    reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                    num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1732
1733

                    // Update head
lijian6's avatar
lijian6 committed
1734
1735
1736
1737
                    if(lane_id < NUM_MAX_NVL_PEERS) {
                        expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
                                        : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1738
                }
lijian6's avatar
lijian6 committed
1739
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1740
                sync_large_warp(token_start_idx, 1);
lijian6's avatar
lijian6 committed
1741

Chenggang Zhao's avatar
Chenggang Zhao committed
1742
                // Issue RDMA send
lijian6's avatar
lijian6 committed
1743
1744
                if(sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1745
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
1746
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1747
1748
1749
1750
                        shmem_ctx_schar_put_nbi_warp(ctx,
#else
                        shmemx_int8_put_nbi_warp(
#endif
lijian6's avatar
lijian6 committed
1751
1752
1753
1754
1755
1756
                            rdma_channel_data.recv_buffer(rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            rdma_channel_data.send_buffer(dst_rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            num_chunked_tokens * num_bytes_per_rdma_token,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1757
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1758
1759
1760
1761
                        shmem_ctx_quiet(ctx);                
#else
                        shmem_fence();
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1762
1763
1764
1765
1766
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
lijian6's avatar
lijian6 committed
1767
                    syncwarp();
lijian6's avatar
lijian6 committed
1768
                    if(lane_id == 0) {
lijian6's avatar
lijian6 committed
1769
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1770
1771
1772
1773
1774
                        shmem_ctx_ulong_atomic_add(ctx,
#else
                        shmem_signal_op_add(
#endif
                            rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
lijian6's avatar
lijian6 committed
1775
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1776
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1777
1778
1779
1780
                }
            }

            // Retired
lijian6's avatar
lijian6 committed
1781
            syncwarp();
lijian6's avatar
lijian6 committed
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
            if(lane_id == 0) {
                forwarder_retired[target_rank] = true;
            }
        } else if (warp_role == WarpRole::kRDMACoordinator) {
            // Coordinator
            // Sync shared memory status
            // sync_forwarder_smem();
            __syncthreads();
            constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
                // Retired
                if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for(int i = 0; i < kNumRDMARanks; ++i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for(int j = 0; j < num_warps_per_rdma_rank; ++j)
                            if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);

                        if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                        }
                    }
                }

                // Nanosleep and let other warps work
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
            }
        } else if(warp_role == WarpRole::kRDMAReceiver) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1821
1822
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1823
            EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lijian6's avatar
lijian6 committed
1824
1825
1826
1827
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1828
1829
1830

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1831
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1832
1833
1834

            // Iterate over all tokens and combine
            int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1835
            for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1836
                // Read expected head
lijian6's avatar
lijian6 committed
1837
                EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1838
                int expected_head = -1;
lijian6's avatar
lijian6 committed
1839
1840
1841
1842
                if(lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
                                        : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1843
1844
1845
                }

                // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1846
                auto start_time = wall_clock64();
Chenggang Zhao's avatar
Chenggang Zhao committed
1847
                while (cached_channel_tail_idx <= expected_head) {
lijian6's avatar
lijian6 committed
1848
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
Chenggang Zhao's avatar
Chenggang Zhao committed
1849
1850

                    // Timeout check
lijian6's avatar
lijian6 committed
1851
1852
1853
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                                channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1854
1855
1856
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1857
                syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1858
1859

                // Combine current token
lijian6's avatar
lijian6 committed
1860
1861
1862
1863
1864
1865
1866
1867
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                            expected_head, lane_id,
                                                                            hidden_int4, num_topk,
                                                                            combined_x + token_idx * hidden_int4,
                                                                            combined_topk_weights + token_idx * num_topk,
                                                                            num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1868
1869
1870
            }

            // Retired
lijian6's avatar
lijian6 committed
1871
            syncwarp();
lijian6's avatar
lijian6 committed
1872
1873
1874
1875
            if(lane_id == 0) {
                rdma_receiver_retired[target_rank] = true;
            }
        } else if(warp_role == WarpRole::kNVLCoordinator) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1876
1877
            // Coordinator
            // Sync shared memory status
lijian6's avatar
lijian6 committed
1878
1879
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1880
1881
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1882
            int last_rdma_head               = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
1883
            int last_nvl_head[kNumRDMARanks] = {0};
lijian6's avatar
lijian6 committed
1884
1885
            int dst_rdma_rank                = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
lijian6's avatar
lijian6 committed
1886
1887
1888
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1889
                // Retired
lijian6's avatar
lijian6 committed
1890
                if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
Chenggang Zhao's avatar
Chenggang Zhao committed
1891
1892
1893
                    break;

                // Find minimum head for RDMA ranks
lijian6's avatar
lijian6 committed
1894
                {
Chenggang Zhao's avatar
Chenggang Zhao committed
1895
                    int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1896
1897
1898
    #pragma unroll
                    for(int i = 0; i < kNumRDMAReceivers; ++i)
                        if(not rdma_receiver_retired[i])
lijian6's avatar
lijian6 committed
1899
                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
lijian6's avatar
lijian6 committed
1900
1901

                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lijian6's avatar
lijian6 committed
1902
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1903
1904
1905
1906
1907
                        shmem_ctx_ulong_atomic_add(ctx,
#else
                        shmem_signal_op_add(
#endif
                            rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
lijian6's avatar
lijian6 committed
1908
1909
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

1910
1911
                        last_rdma_head = min_head;
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1912
1913
1914
                }

                // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1915
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1916
1917
1918
            }
        }
    }
lijian6's avatar
lijian6 committed
1919
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1920
1921
    shmem_wg_ctx_destroy(&ctx);
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1922
1923
}

lijian6's avatar
lijian6 committed
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1934
    constexpr int kNumCombineForwarderWarps = 8;
lijian6's avatar
lijian6 committed
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953

#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                                        \
    {                                                                                              \
        auto combine_func =                                                                        \
            low_latency_mode                                                                       \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>           \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>;         \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights,       \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights,            \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1),        \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,           \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr,                    \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

lijian6's avatar
lijian6 committed
1954
    int num_rdma_ranks           = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1955
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
lijian6's avatar
lijian6 committed
1956
1957
    int num_forwarder_warps      = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
lijian6's avatar
lijian6 committed
1958
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
1959
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
lijian6's avatar
lijian6 committed
1960
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
lijian6's avatar
lijian6 committed
1961
    EP_HOST_ASSERT(type == HIP_R_16BF);
Chenggang Zhao's avatar
Chenggang Zhao committed
1962

lijian6's avatar
lijian6 committed
1963
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1964
1965
1966
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
1967

Chenggang Zhao's avatar
Chenggang Zhao committed
1968
1969
1970
} // namespace internode

} // namespace deep_ep
lijian6's avatar
lijian6 committed
1971

lijian6's avatar
lijian6 committed
1972
1973
1974
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
lijian6's avatar
lijian6 committed
1975
1976

#endif