internode.cu 106 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
#include "buffer.cuh"
lijian6's avatar
lijian6 committed
3
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
4
5
#include "launch.cuh"
#include "utils.cuh"
6
#include "shmem_wrapper.cuh"
lijian6's avatar
lijian6 committed
7
8
9
10

#ifndef DISABLE_ROCSHMEM

// TODO: fix unroll warnings
lijian6's avatar
lijian6 committed
11
12
13
14
15
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
Chenggang Zhao's avatar
Chenggang Zhao committed
16
17
18
19
20

namespace deep_ep {

namespace internode {

21
extern shmem_team_t cpu_rdma_team;
Chenggang Zhao's avatar
Chenggang Zhao committed
22
23

struct SourceMeta {
lishen's avatar
lishen committed
24
    int src_rdma_rank, is_token_in_nvl_rank_bits;   // sizeof(SourceMeta) = 8
Chenggang Zhao's avatar
Chenggang Zhao committed
25
26
27
28
29
30

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // TODO: faster encoding
lijian6's avatar
lijian6 committed
31
32
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        src_rdma_rank             = rdma_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
33
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
lijian6's avatar
lijian6 committed
34
35
#pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
36
37
38
39
40
41
42
43
44
45
46
47
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};

// Size in bytes of one `SourceMeta` record, exposed as an `int` for host-side callers.
int get_source_meta_bytes() {
    const auto meta_size = sizeof(SourceMeta);
    return static_cast<int>(meta_size);
}

lijian6's avatar
lijian6 committed
48
49
50
51
52
// Bytes reserved per token in the RDMA buffers: hidden payload (`hidden_int4` int4 words),
// one `SourceMeta`, `num_scales` floats, `num_topk_idx` ints and `num_topk_weights` floats,
// passed through ALIGN(..., sizeof(int4)) (presumably rounding up to a 16-byte multiple).
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    const auto hidden_bytes = hidden_int4 * sizeof(int4);
    const auto meta_bytes   = sizeof(SourceMeta);
    const auto scale_bytes  = num_scales * sizeof(float);
    const auto idx_bytes    = num_topk_idx * sizeof(int);
    const auto weight_bytes = num_topk_weights * sizeof(float);
    const auto raw_bytes    = hidden_bytes + meta_bytes + scale_bytes + idx_bytes + weight_bytes;
    return static_cast<int>(ALIGN(raw_bytes, sizeof(int4)));
}

lijian6's avatar
lijian6 committed
57
58
59
// Return `int32_t` offset and count to clean in the RDMA buffer (both in units of `int`).
//   first  (offset): bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms,
//                    divided by sizeof(int).
//   second (count):  (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms ints to zero
//                    starting at that offset.
//
// The byte product is formed in 64-bit. `get_num_bytes_per_rdma_token` returns `int`, and
// multiplying it by the other `int` factors can exceed INT_MAX (signed overflow, UB) even when
// the final int count still fits. A wrapped offset could also slip past the host-side
// `<= num_rdma_bytes` size check. Results above INT_MAX are still truncated by the `int` return
// type.
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    const int64_t bytes_per_token =
        get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights);
    const int64_t region_bytes =
        bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms;
    const int64_t clean_offset = region_bytes / static_cast<int64_t>(sizeof(int));
    const int64_t clean_count =
        static_cast<int64_t>(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms;
    return {static_cast<int>(clean_offset), static_cast<int>(clean_count)};
}
lijian6's avatar
lijian6 committed
65
66
67
68
// Return `int32_t` offset and count to clean in the NVL buffer (both in units of `int`).
//   first  (offset): num_nvl_recv_buffer_tokens * token_bytes * num_nvl_ranks * num_sms, divided by
//                    sizeof(int). token_bytes is hidden data + scales + top-k indices + top-k
//                    weights + one `SourceMeta`, summed without int4 alignment (unlike the
//                    RDMA per-token size).
//   second (count):  num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms ints to zero.
// All byte arithmetic is done in `size_t`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");

    const auto token_bytes = hidden_int4 * sizeof(int4) + num_scales * sizeof(float) +
                             num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) +
                             sizeof(SourceMeta);
    const auto region_bytes = num_nvl_recv_buffer_tokens * token_bytes * num_nvl_ranks * num_sms;
    const int clean_offset  = static_cast<int>(region_bytes / sizeof(int));
    const int clean_count   = num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms;
    return {clean_offset, clean_count};
}

template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
82
83
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
84
85
86
87
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}

template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
88
__forceinline__ __device__ void
lijian6's avatar
lijian6 committed
89
dushmem_barrier_with_same_gpu_idx(const shmem_team_t &rdma_team) {
lijian6's avatar
lijian6 committed
90
91
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm
92
    kLowLatencyMode ? shmem_barrier(rdma_team) : shmem_device_barrier_all();
Chenggang Zhao's avatar
Chenggang Zhao committed
93
94
95
96
}

template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
lijian6's avatar
lijian6 committed
97
98
99
100
101
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
                int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
Chenggang Zhao's avatar
Chenggang Zhao committed
102
                const int nvl_clean_offset, const int nvl_num_int_clean,
lijian6's avatar
lijian6 committed
103
104
105
                int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
                int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
106
                const shmem_team_t rdma_team) {
lijian6's avatar
lijian6 committed
107
108
109
110
    auto sm_id     = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
         lane_id     = get_lane_id();
    auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
Chenggang Zhao's avatar
Chenggang Zhao committed
111
112

    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
113
114
    auto num_rdma_experts = num_experts / kNumRDMARanks,
         num_nvl_experts  = num_rdma_experts / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
115
116
117

    if (sm_id == 0) {
        // Communication with others
lijian6's avatar
lijian6 committed
118
119
        // Global barrier: the first warp do intra-node sync, the second warp do internode sync
        if (thread_id == kWarpSize)
lijian6's avatar
lijian6 committed
120
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
121

lijian6's avatar
lijian6 committed
122
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
123

Chenggang Zhao's avatar
Chenggang Zhao committed
124
        // Send numbers of tokens per rank/expert to RDMA ranks
lijian6's avatar
lijian6 committed
125
126
127
        auto rdma_buffer_ptr_int        = reinterpret_cast<int *>(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed = SymBuffer<int>(
            rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
Chenggang Zhao's avatar
Chenggang Zhao committed
128
129

        // Clean up for later data dispatch
lijian6's avatar
lijian6 committed
130
131
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <=
                                  rdma_clean_offset * sizeof(int));
Chenggang Zhao's avatar
Chenggang Zhao committed
132
133
134
135
136
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer
        for (int i = thread_id; i < num_ranks; i += num_threads)
lijian6's avatar
lijian6 committed
137
138
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] =
                num_tokens_per_rank[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
139
        for (int i = thread_id; i < num_experts; i += num_threads)
lijian6's avatar
lijian6 committed
140
141
142
            rdma_recv_num_tokens_mixed.send_buffer(
                i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =
                num_tokens_per_expert[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
143
        if (thread_id < kNumRDMARanks)
lijian6's avatar
lijian6 committed
144
145
146
            rdma_recv_num_tokens_mixed.send_buffer(
                thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
                num_tokens_per_rdma_rank[thread_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
147
148
149
150
151
        __syncthreads();

        // Issue send
        // TODO: more light fence or barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
lishen's avatar
lishen committed
152
153
154
        for (int i = warp_id; i < kNumRDMARanks; i += num_warps) {
            if (i != rdma_rank) {
                shmemx_int_put_nbi_warp(
lijian6's avatar
lijian6 committed
155
                rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
lishen's avatar
lishen committed
156
                rdma_recv_num_tokens_mixed.send_buffer(i),
lijian6's avatar
lijian6 committed
157
                NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
lishen's avatar
lishen committed
158
159
160
161
162
163
164
165
166
167
                translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank));
            } else {
                UNROLLED_WARP_COPY(1,
                                   lane_id,
                                   NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                   rdma_recv_num_tokens_mixed.send_buffer(i),
                                   ld_volatile_global,
                                   st_na_global);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
168
        }
alpha-baby's avatar
alpha-baby committed
169
        __syncthreads();
lishen's avatar
lishen committed
170

Chenggang Zhao's avatar
Chenggang Zhao committed
171
        if (thread_id == 0)
lijian6's avatar
lijian6 committed
172
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
lijian6's avatar
lijian6 committed
173

Chenggang Zhao's avatar
Chenggang Zhao committed
174
175
176
177
178
        __syncthreads();

        // NVL buffers
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
179
180
181
182
183
184
185
186
187
188
        auto nvl_reduced_num_tokens_per_expert =
            Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank =
            AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert =
            AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank =
            AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert =
            AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
Chenggang Zhao's avatar
Chenggang Zhao committed
189
190

        // Clean up for later data dispatch
lijian6's avatar
lijian6 committed
191
192
193
194
195
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
                                      nvl_send_num_tokens_per_rank.total_bytes +
                                      nvl_send_num_tokens_per_expert.total_bytes <=
                                  nvl_clean_offset * sizeof(int));
Chenggang Zhao's avatar
Chenggang Zhao committed
196
197
198
199
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer
lijian6's avatar
lijian6 committed
200
        // TODO: may use DUSHMEM reduction
Chenggang Zhao's avatar
Chenggang Zhao committed
201
202
203
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
lijian6's avatar
lijian6 committed
204
205
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
206
207
208
209
210
211
212
213
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens
        if (thread_id == 0) {
            int sum = 0;
lijian6's avatar
lijian6 committed
214
215
216
217
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i) {
                sum +=
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
Chenggang Zhao's avatar
Chenggang Zhao committed
218
219
                recv_rdma_rank_prefix_sum[i] = sum;
            }
lijian6's avatar
lijian6 committed
220
221
            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
222
223
224
225
226
            *moe_recv_rdma_counter_mapped = sum;
        }

        // Send numbers of tokens per rank/expert to NVL ranks
        if (thread_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
227
228
229
230
231
232
233
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] =
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            for (int i = 0; i < num_nvl_experts; ++i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
                    nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
Chenggang Zhao's avatar
Chenggang Zhao committed
234
        }
235
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
236

lijian6's avatar
lijian6 committed
237
        // Reduce number of tokens per rank/expert
Chenggang Zhao's avatar
Chenggang Zhao committed
238
239
240
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            int sum = 0;
lijian6's avatar
lijian6 committed
241
            for (int i = 0; i < num_ranks; ++i) {
Chenggang Zhao's avatar
Chenggang Zhao committed
242
243
244
245
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
lijian6's avatar
lijian6 committed
246
247
            while (ld_volatile_global(moe_recv_counter_mapped) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
248
249
250
251
            *moe_recv_counter_mapped = sum;
        }
        if (thread_id < num_nvl_experts) {
            int sum = 0;
lijian6's avatar
lijian6 committed
252
253
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
254
255
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
lijian6's avatar
lijian6 committed
256
257
            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
258
259
260
261
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }

        // Finally barrier
lijian6's avatar
lijian6 committed
262
        if (thread_id == kWarpSize)
lijian6's avatar
lijian6 committed
263
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
lijian6's avatar
lijian6 committed
264

265
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
266
267
268
269
270
    } else {
        // Calculate meta data
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
271
272
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
273
274
275

            // Iterate over tokens
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
lijian6's avatar
lijian6 committed
276
277
278
279
280
281
282
283
284
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize) {
                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
                                          "Invalid number of NVL peers");
                auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>(
                    is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                auto is_token_in_rank_values =
                    reinterpret_cast<const bool *>(&is_token_in_rank_uint64);
#pragma unroll
                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)
Chenggang Zhao's avatar
Chenggang Zhao committed
285
286
287
288
289
290
                    per_nvl_rank_count[j] += is_token_in_rank_values[j];
                total_count += (is_token_in_rank_uint64 != 0);
            }

            // Warp reduce
            total_count = warp_reduce_sum(total_count);
lijian6's avatar
lijian6 committed
291
292
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
293
294
295
                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);

            // Write into channel matrix
lijian6's avatar
lijian6 committed
296
297
298
299
300
301
            if (lane_id == 0) {
#pragma unroll
                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) *
                                                  num_channels +
                                              channel_id] = per_nvl_rank_count[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
302
303
304
305
306
307
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
            }
        }

        // Calculate prefix sum
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
308
        if (thread_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
309
            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
lijian6's avatar
lijian6 committed
310
            for (int i = 1; i < num_channels; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
311
312
313
                prefix_row[i] += prefix_row[i - 1];
        }

lijian6's avatar
lijian6 committed
314
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
315
        if (thread_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
316
317
318
            auto prefix_row = gbl_channel_prefix_matrix +
                              (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
            for (int i = 1; i < num_channels; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
319
320
321
322
323
                prefix_row[i] += prefix_row[i - 1];
        }
    }
}

lijian6's avatar
lijian6 committed
324
325
326
327
328
329
330
331
332
333
334
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
335
                     bool low_latency_mode) {
lijian6's avatar
lijian6 committed
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                \
    {                                                                                              \
        auto notify_dispatch_func = low_latency_mode ? notify_dispatch<true, num_rdma_ranks>       \
                                                     : notify_dispatch<false, num_rdma_ranks>;     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks,   \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert,         \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens,             \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second,         \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix,               \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,        \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team);               \
    }                                                                                              \
    break

    constexpr int kNumThreads    = 256;
    const auto    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
353
354

    // Get clean meta
lijian6's avatar
lijian6 committed
355
356
357
358
359
360
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
lishen's avatar
lishen committed
361
362
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
363
364
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
lishen's avatar
lishen committed
365
366
367
    // add assert origin kernel
    EP_HOST_ASSERT(num_rdma_ranks <= kNumThreads);
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kNumThreads, "Assert NUM_MAX_NVL_PEERS <= kNumThreads");
Chenggang Zhao's avatar
Chenggang Zhao committed
368
369
370
371
372
373
374
375
376
377
378
379

    // Launch kernel
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Number of RDMA ranks a token can be sent to: num_rdma_ranks, capped at 8.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    return (num_rdma_ranks > 8) ? 8 : num_rdma_ranks;
}

lijian6's avatar
lijian6 committed
380
381
382
383

template <bool kLowLatencyMode,
          int kNumRDMARanks,
          bool kCachedMode,
lijian6's avatar
lijian6 committed
384
385
          int kNumDispatchRDMASenderWarps,
          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
lijian6's avatar
lijian6 committed
386
387
388
389
390
391
392
393
394
395
396
397
398
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
        SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
        const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
        int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
        int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
        const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
        const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
        int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
        int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
        int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
        int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
        int num_ranks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
399
    enum class WarpRole {
lijian6's avatar
lijian6 committed
400
401
402
403
404
        kRDMASender,            // 从x写入到RDMA发送缓存
        kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
        kRDMAAndNVLForwarder,   // 从RDMA接收缓存转写到ipc nvl缓存
        kForwarderCoordinator,  // 向远端RDMA确认接收
        kNVLReceivers           // 从nvl缓存写入到recv_x
Chenggang Zhao's avatar
Chenggang Zhao committed
405
    };
lijian6's avatar
lijian6 committed
406
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
407
408
409
    __shared__ shmem_ctx_t ctx;
    shmem_wg_ctx_create(&ctx);
#endif
lijian6's avatar
lijian6 committed
410
411
412

    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
lijian6's avatar
lijian6 committed
413
414
415
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
416
417
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

lijian6's avatar
lijian6 committed
418
419
420
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);

Chenggang Zhao's avatar
Chenggang Zhao committed
421
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
lijian6's avatar
lijian6 committed
422
423
424
425
426
427
428
429
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASender, -1};
            } else if(warp_id == kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASenderCoordinator, -1};
            }
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            if(warp_id < NUM_MAX_NVL_PEERS) {
Chenggang Zhao's avatar
Chenggang Zhao committed
430
431
432
433
434
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else {
lijian6's avatar
lijian6 committed
435
            return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
Chenggang Zhao's avatar
Chenggang Zhao committed
436
437
        }
    }();
lijian6's avatar
lijian6 committed
438

lijian6's avatar
lijian6 committed
439
    auto warp_role = role_meta.first;
Chenggang Zhao's avatar
Chenggang Zhao committed
440
441
442
    auto target_rank = role_meta.second; // Not applicable for RDMA senders

    // RDMA symmetric layout
lijian6's avatar
lijian6 committed
443
444
445
446
447
448
    auto hidden_bytes             = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
449
450

    // NVL buffer layouts
lijian6's avatar
lijian6 committed
451
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
Chenggang Zhao's avatar
Chenggang Zhao committed
452
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
lijian6's avatar
lijian6 committed
453
    int rs_wr_rank = 0, ws_rr_rank = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
454
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
lijian6's avatar
lijian6 committed
455
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
456
    if (warp_role == WarpRole::kNVLReceivers)
lijian6's avatar
lijian6 committed
457
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
458
459

    // Allocate buffers
lijian6's avatar
lijian6 committed
460
461
462
463
464
465
466
467
468
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
469
470

    // RDMA sender warp synchronization
lijian6's avatar
lijian6 committed
471
472
473
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
474

lijian6's avatar
lijian6 committed
475
476
    // NVL and RDMA coordinate Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
477
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
lijian6's avatar
lijian6 committed
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498

    // Place the main logic of your kernel here, using the parameters above.
    if(warp_role == WarpRole::kRDMASender) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
        然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
        最后,它复制相关的数据到对称发送缓冲区。

        kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
        期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
        同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
        */
        // 获取任务范围
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // 清理共享内存
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
        if(warp_id == 0 && lane_id == 0) {
            rdma_send_next_token_idx = token_start_idx;
lijian6's avatar
lijian6 committed
499
        }
lijian6's avatar
lijian6 committed
500
501
502
        if(warp_id == 0 && lane_id < kNumRDMARanks) {
            rdma_send_channel_tail[lane_id]      = 0;
            rdma_send_channel_next_tail[lane_id] = 0;
lijian6's avatar
lijian6 committed
503
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
504

lijian6's avatar
lijian6 committed
505
506
507
508
509
510
        // 发送本通道中的令牌数量,通过 `-value - 1` 表示
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
        // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
        // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
        for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
511
            if (lane_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
512
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
513
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
514
                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
515
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
516
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
517
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
lijian6's avatar
lijian6 committed
518
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
519
            }
lijian6's avatar
lijian6 committed
520

lijian6's avatar
lijian6 committed
521
522
            syncwarp();
            if (dst_rdma_rank != rdma_rank) {
lijian6's avatar
lijian6 committed
523
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
524
525
526
527
528
                shmem_ctx_int_put_nbi_warp(ctx, 
#else
                shmemx_int_put_nbi_warp(
#endif
                rdma_channel_meta.recv_buffer(rdma_rank),
lijian6's avatar
lijian6 committed
529
530
                rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
531
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
532
        }
533

lijian6's avatar
lijian6 committed
534
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
535
536
537
538
539
        shmem_ctx_quiet(ctx);                
#else
        shmem_fence();
#endif

lijian6's avatar
lijian6 committed
540
541
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
542

lijian6's avatar
lijian6 committed
543
        // 遍历令牌并复制到缓冲区
Chenggang Zhao's avatar
Chenggang Zhao committed
544
        int64_t token_idx;
lijian6's avatar
lijian6 committed
545
546
547
548
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // 读取RDMA秩的存在性
Chenggang Zhao's avatar
Chenggang Zhao committed
549
            uint64_t is_token_in_rank_uint64 = 0;
lijian6's avatar
lijian6 committed
550
551
552
            if(lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
553

lijian6's avatar
lijian6 committed
554
555
556
557
            // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
            while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
                // 等待
            }
lijian6's avatar
lijian6 committed
558
            syncwarp();
559

lijian6's avatar
lijian6 committed
560
            // 获取下一个尾部位置
lijian6's avatar
lijian6 committed
561
            int rdma_tail_idx = -1;
lijian6's avatar
lijian6 committed
562
            if(is_token_in_rank_uint64 != 0) {
lijian6's avatar
lijian6 committed
563
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
lijian6's avatar
lijian6 committed
564
565
566
567
568

                // 与kForwarderCoordinator相互配合,调节发送数据的频率
                while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
569
            }
lijian6's avatar
lijian6 committed
570
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
571

lijian6's avatar
lijian6 committed
572
573
            // 存储RDMA头部以供合并
            if(lane_id < kNumRDMARanks && !kCachedMode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
574
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
lijian6's avatar
lijian6 committed
575
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
576

lijian6's avatar
lijian6 committed
577
578
579
580
            // 更新最后一个令牌尾部
            if(last_rdma_tail_idx >= 0) {
                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            }
lijian6's avatar
lijian6 committed
581
582
            last_rdma_tail_idx = rdma_tail_idx;

lijian6's avatar
lijian6 committed
583
584
585
586
            // 释放顺序锁
            if(lane_id == 0) {
                rdma_send_next_token_idx += 1;
            }
lijian6's avatar
lijian6 committed
587

lijian6's avatar
lijian6 committed
588
            // 广播尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
589
            SourceMeta src_meta;
lijian6's avatar
lijian6 committed
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            /*
            该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
            */
            #pragma unroll
            for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
                // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
                if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
                    // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
                    //  计算slot_idx在接收缓冲区中的位置
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;

                    // 存储当前RDMA秩到topk_ranks数组中
                    topk_ranks[num_topk_ranks] = i;

                    // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
lijian6's avatar
lijian6 committed
607
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
lijian6's avatar
lijian6 committed
608
609
610
611
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);

                    // 如果当前lane_id等于num_topk_ranks,则更新src_meta
                    if(lane_id == num_topk_ranks) {
lijian6's avatar
lijian6 committed
612
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
lijian6's avatar
lijian6 committed
613
614
615
616
617
                    }

                    // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
                    // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
lijian6's avatar
lijian6 committed
618
                }
lijian6's avatar
lijian6 committed
619
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
620
621
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

lishen's avatar
lishen committed
622
623
624
625
626
627
628
629
630
631
            //////////////// 复制数据到发送缓冲区 ////////////////
            // 复制源元数据到对称发送缓冲区
            if(lane_id < num_topk_ranks) {
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            }

            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
            }

lijian6's avatar
lijian6 committed
632
633
634
635
636
            // 复制 `x` 到对称发送缓冲区
            auto st_broadcast = [=](const int key, const int4& value) {
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
637
            };
lijian6's avatar
lijian6 committed
638
639
640
641
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
            }
lijian6's avatar
lijian6 committed
642

lijian6's avatar
lijian6 committed
643
644
            // 复制 `x_scales` 到对称发送缓冲区
            for(int i = lane_id; i < num_scales; i += kWarpSize) {
lijian6's avatar
lijian6 committed
645
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
lijian6's avatar
lijian6 committed
646
647
648
649
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
                }
            }
lishen's avatar
lishen committed
650

lijian6's avatar
lijian6 committed
651
652
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
lijian6's avatar
lijian6 committed
653
            }
654

lijian6's avatar
lijian6 committed
655
656
            // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
            for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
Chenggang Zhao's avatar
Chenggang Zhao committed
657
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
lijian6's avatar
lijian6 committed
658
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
659
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
lijian6's avatar
lijian6 committed
660
661
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
662
            }
lijian6's avatar
lijian6 committed
663
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
664

lijian6's avatar
lijian6 committed
665
666
667
668
669
670
        // 结尾部分
        // 获取顺序锁
        while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
            // 等待
        }

lijian6's avatar
lijian6 committed
671
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
672

lijian6's avatar
lijian6 committed
673
674
675
676
        // 更新最后一个令牌尾部
        if(last_rdma_tail_idx >= 0) {
            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
677

lijian6's avatar
lijian6 committed
678
679
680
681
682
683
684
685
686
687
688
689
        // 释放顺序锁
        if(lane_id == 0) {
            rdma_send_next_token_idx += 1;
        }
    } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
        如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
        最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。

        kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
lijian6's avatar
lijian6 committed
690
        dushmem内存一致性(dushmem_fence)和原子操作(dushmemx_signal_op),减少硬同步,提升整体效率。
lijian6's avatar
lijian6 committed
691
692
693
694
695
696
        */
        if(warp_id > kNumDispatchRDMASenderWarps) {
            return;
        }
        // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
697

lijian6's avatar
lijian6 committed
698
699
700
        // 同步共享内存,确保所有线程在继续之前都达到了这一点
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
701

lijian6's avatar
lijian6 committed
702
        // 计算当前通道需要发送的令牌数
Chenggang Zhao's avatar
Chenggang Zhao committed
703
        int num_tokens_to_send = 0;
lijian6's avatar
lijian6 committed
704
        if(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
705
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
lijian6's avatar
lijian6 committed
706
707
            if(channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
708
709
        }

lijian6's avatar
lijian6 committed
710
        // 记录上次发出的尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
711
        int last_issued_tail = 0;
lijian6's avatar
lijian6 committed
712
713
714
715
716
717
718
        // 当有任何RDMA秩需要发送令牌时,继续循环
        while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
            for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // 计算目标RDMA秩
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;

                // 获取同步后的需要发送的令牌数
lijian6's avatar
lijian6 committed
719
                synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
720

lijian6's avatar
lijian6 committed
721
722
723
724
                if(synced_num_tokens_to_send == 0)
                    continue; // 如果没有令牌需要发送,则跳过

                // 读取进度
lijian6's avatar
lijian6 committed
725
                auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
lijian6's avatar
lijian6 committed
726
727
728
729
730
                auto processed_tail          = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed    = processed_tail - synced_last_issued_tail;

                // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
                if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
731
732
                    continue;

lijian6's avatar
lijian6 committed
733
734
735
736
737
738
                // 计算本次需要发出的令牌数
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);

                // 发出RDMA发送请求
                if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
739
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
740
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
741
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
742
743
744
745
                    shmem_ctx_schar_put_nbi_warp(ctx,
#else
                    shmemx_int8_put_nbi_warp(
#endif
lijian6's avatar
lijian6 committed
746
747
748
749
750
751
                        rdma_channel_data.recv_buffer(rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        rdma_channel_data.send_buffer(dst_rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        num_bytes_per_rdma_token * num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
752
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
753
754
755
756
                    shmem_ctx_quiet(ctx);                
#else
                    shmem_fence();
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
757
                } else {
lijian6's avatar
lijian6 committed
758
                    // 对于本地RDMA秩,使用较轻的内存屏障
Chenggang Zhao's avatar
Chenggang Zhao committed
759
760
761
                    memory_fence();
                }

lijian6's avatar
lijian6 committed
762
                // 更新尾部位置
lijian6's avatar
lijian6 committed
763
                syncwarp();
lijian6's avatar
lijian6 committed
764
                if(lane_id == dst_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
765
766
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
lijian6's avatar
lijian6 committed
767
                    // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
lijian6's avatar
lijian6 committed
768
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
769
770
771
772
773
                    shmem_ctx_ulong_atomic_add(ctx,
#else
                    shmem_signal_op_add(
#endif
                        rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
lijian6's avatar
lijian6 committed
774
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
775
776
                }
            }
lijian6's avatar
lijian6 committed
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
        } // while(__any(num_tokens_to_send > 0))
    } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
        它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
        接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
        然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
        最后,它同步头部和尾部索引,并标记通道为退役状态。
        */
        // RDMA消费者和NVL生产者
        const auto dst_nvl_rank          = target_rank;                                       // 目标NVL秩
        const auto dst_rank              = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;      // 目标秩
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);              // 目标秩专家开始
        const auto dst_rank_expert_end   = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束

        // 等待计数器到达
Chenggang Zhao's avatar
Chenggang Zhao committed
793
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
lijian6's avatar
lijian6 committed
794
795
        EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
796
797
798
799
800
801
802
803
804
        if(lane_id < kNumRDMARanks) {
            while(true) {
                // 对应于kRDMASender中的数据写入
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);                     // 是nvl节点的起始地址
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);            // 本rdma节点的起始地址
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);        // 本节点的结束地址
                if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
                    // 通知NVL秩
Chenggang Zhao's avatar
Chenggang Zhao committed
805
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
lijian6's avatar
lijian6 committed
806
807
808
                    EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);

                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
809
810
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

lijian6's avatar
lijian6 committed
811
812
                    // 保存从RDMA通道接收的令牌计数
                    src_rdma_channel_prefix = -meta_2 - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
813
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
lijian6's avatar
lijian6 committed
814
815
816
817
818
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
                    if(!kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;

                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
Chenggang Zhao's avatar
Chenggang Zhao committed
819
820
821
822
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

lijian6's avatar
lijian6 committed
823
824
825
826
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
Chenggang Zhao's avatar
Chenggang Zhao committed
827
828
829
830
                    trap();
                }
            }
        }
lijian6's avatar
lijian6 committed
831
        syncwarp();
lijian6's avatar
lijian6 committed
832
833

        // 移动缓存的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
834
835
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

lijian6's avatar
lijian6 committed
836
837
838
        // 等待共享内存被清理
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
839

lijian6's avatar
lijian6 committed
840
841
842
843
        // 开始准备处理接受数据,直到所有的数据接受完成。
        // 转发从RDMA缓冲区的令牌
        // 注意:总是从本地秩开始
        int src_rdma_rank = sm_id % kNumRDMARanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
844
845
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
lijian6's avatar
lijian6 committed
846
847
        while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
            // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
lijian6's avatar
lijian6 committed
848
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
849
850
851

            // 用于给kNVLReceivers进行互动,控制数据的传输速度
            while(lane_id == 0) {
lijian6's avatar
lijian6 committed
852
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
lijian6's avatar
lijian6 committed
853
                if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
854
                    break;
lijian6's avatar
lijian6 committed
855
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
856

lijian6's avatar
lijian6 committed
857
858
859
860
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
861
862
863
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
864
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
865

lijian6's avatar
lijian6 committed
866
            // 找到下一个源RDMA秩(轮询)
lijian6's avatar
lijian6 committed
867
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
868
            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
869
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
lijian6's avatar
lijian6 committed
870
871
872
873
874
                if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));

                    if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
Chenggang Zhao's avatar
Chenggang Zhao committed
875
                        break;
lijian6's avatar
lijian6 committed
876
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
877
878
                }

lijian6's avatar
lijian6 committed
879
880
881
882
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
Chenggang Zhao's avatar
Chenggang Zhao committed
883
884
885
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
886

lijian6's avatar
lijian6 committed
887
888
            auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
889

lijian6's avatar
lijian6 committed
890
891
892
893
894
            // 遍历RDMA缓冲区中的每一个令牌
            for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
                void* shifted           = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
lishen's avatar
lishen committed
895
                auto src_meta           = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted)));
lijian6's avatar
lijian6 committed
896
897
898
899
                if(lane_id == src_rdma_rank) {
                    num_tokens_to_recv_from_rdma -= 1;
                }

Chenggang Zhao's avatar
Chenggang Zhao committed
900
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
lijian6's avatar
lijian6 committed
901
                if(lane_id == src_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
902
903
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
lijian6's avatar
lijian6 committed
904
                    if(!kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
905
906
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
lijian6's avatar
lijian6 committed
907
908

                if(!is_in_dst_nvl_rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
909
910
                    continue;

lijian6's avatar
lijian6 committed
911
                // 获取一个空闲槽位
lijian6's avatar
lijian6 committed
912
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
913

lishen's avatar
lishen committed
914
915
916
917
918
919
920
921
922
923
                // 设置 src和dst 位置
                auto src_gpu_buffer_x = reinterpret_cast<int4*>(reinterpret_cast<int8_t*>(shifted) + sizeof(SourceMeta));
                auto src_gpu_buffer_scales = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_x) + hidden_bytes);
                auto src_gpu_buffer_topk_idx = reinterpret_cast<int*>(reinterpret_cast<int8_t*>(src_gpu_buffer_scales) + num_scales * sizeof(float));
                auto src_gpu_buffer_topk_weights = reinterpret_cast<float*>(reinterpret_cast<int8_t*>(src_gpu_buffer_topk_idx) + num_topk * sizeof(int));

                auto dst_gpu_buffer_x = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                auto dst_gpu_buffer_scales = nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales;
                auto dst_gpu_buffer_topk_idx = nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk;
                auto dst_gpu_buffer_topk_weights = nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk;
lijian6's avatar
lijian6 committed
924

lishen's avatar
lishen committed
925
926
927
928
929
930
931
932
933
                if(lane_id == 0) {
                    st_na_global(reinterpret_cast<int64_t*>(nvl_channel_src_meta.buffer() + dst_slot_idx), 
                                 *reinterpret_cast<int64_t*>(&src_meta));
                }

                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                    dst_gpu_buffer_x,
                    src_gpu_buffer_x,
                    ld_direct_global, st_na_global);
lijian6's avatar
lijian6 committed
934
935

                UNROLLED_WARP_COPY(1, lane_id, num_scales,
lishen's avatar
lishen committed
936
937
938
939
940
941
942
943
944
945
946
947
                    dst_gpu_buffer_scales,
                    src_gpu_buffer_scales,
                    ld_direct_global, st_na_global);

                for(int t = lane_id; t < num_topk; t += kWarpSize) {
                    int idx_val = ld_direct_global(reinterpret_cast<int*>(src_gpu_buffer_topk_idx) + t);
                    float w_val = ld_direct_global(reinterpret_cast<float*>(src_gpu_buffer_topk_weights) + t);
                    int new_idx = (idx_val >= dst_rank_expert_begin && idx_val < dst_rank_expert_end) 
                                ? (idx_val - dst_rank_expert_begin) : -1;
                    float new_w = (new_idx != -1) ? w_val : 0.0f;
                    dst_gpu_buffer_topk_idx[t] = new_idx;
                    dst_gpu_buffer_topk_weights[t] = new_w;
Chenggang Zhao's avatar
Chenggang Zhao committed
948
949
                }

lijian6's avatar
lijian6 committed
950
951
                // 在NVL缓冲区不足的情况下,提前停止
                if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
952
953
954
                    src_rdma_tail = i + 1;
            }

lijian6's avatar
lijian6 committed
955
956
957
            // 同步头部索引
            if(lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
958

lijian6's avatar
lijian6 committed
959
            // 移动尾部索引,与kNVLReceivers互相通信使用
lijian6's avatar
lijian6 committed
960
            syncwarp();
lijian6's avatar
lijian6 committed
961
962
963
            if(lane_id == 0) {
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
964
965
966
        }

        // Retired
lijian6's avatar
lijian6 committed
967
        syncwarp();
lijian6's avatar
lijian6 committed
968
        if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
969
            forward_channel_retired[dst_nvl_rank] = true;
lijian6's avatar
lijian6 committed
970
971
972
973
974
975
976
977
978
        }
    } else if(warp_role == WarpRole::kForwarderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
        它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
        然后,它清理共享内存,并初始化转发通道的头部和退役状态。
        接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
        否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
        */
Chenggang Zhao's avatar
Chenggang Zhao committed
979
        // Extra warps for forwarder coordinator should exit directly
lijian6's avatar
lijian6 committed
980
        if (warp_id > NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
981
982
            return;

lijian6's avatar
lijian6 committed
983
984
985
986
987
988
        // 转发warp协调器
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        // 清理共享内存
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
        for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
Chenggang Zhao's avatar
Chenggang Zhao committed
989
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
lijian6's avatar
lijian6 committed
990
        if(lane_id < NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
991
            forward_channel_retired[lane_id] = false;
lijian6's avatar
lijian6 committed
992
993
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
994
995

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
lijian6's avatar
lijian6 committed
996
997
998

        while(true) {
            // 找到最小的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
999
            int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1000
#pragma unroll
lijian6's avatar
lijian6 committed
1001
1002
            for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if(!forward_channel_retired[i])
lijian6's avatar
lijian6 committed
1003
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
lijian6's avatar
lijian6 committed
1004
1005

            if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1006
                break;
lijian6's avatar
lijian6 committed
1007
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1008

lijian6's avatar
lijian6 committed
1009
1010
            // 更新远程头部
            if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
lijian6's avatar
lijian6 committed
1011
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1012
1013
1014
1015
1016
                shmem_ctx_ulong_atomic_add(ctx,
#else
                shmem_signal_op_add(
#endif
                    rdma_channel_head.buffer(rdma_rank), min_head - last_head,
lijian6's avatar
lijian6 committed
1017
                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
1018
1019
                last_head = min_head;
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1020

lijian6's avatar
lijian6 committed
1021
            // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1022
            __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1023
        }
lijian6's avatar
lijian6 committed
1024
1025
1026
1027
1028
1029
1030
1031
    } else if(warp_role == WarpRole::kNVLReceivers) {
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }

        // Place the main logic of your kernel here, using the parameters above.
        // NVL消费者
        // 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
Chenggang Zhao's avatar
Chenggang Zhao committed
1032
        int src_nvl_rank = target_rank, total_offset = 0;
lijian6's avatar
lijian6 committed
1033
1034
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1035
1036
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

lijian6's avatar
lijian6 committed
1037
1038
        // 接收通道偏移
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
1039
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1040
1041

        while(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1042
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1043
            end_offset   = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1044
            if(start_offset < 0 && end_offset < 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1045
1046
1047
1048
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
lijian6's avatar
lijian6 committed
1049
1050
1051
1052
            // 超时检查
            if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
1053
1054
1055
                trap();
            }
        }
lijian6's avatar
lijian6 committed
1056

Chenggang Zhao's avatar
Chenggang Zhao committed
1057
1058
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

lijian6's avatar
lijian6 committed
1059
1060
1061
        // 保存以供合并使用
        if(lane_id < kNumRDMARanks && !kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
lijian6's avatar
lijian6 committed
1062
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1063
1064

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1065
1066
        while(num_tokens_to_recv > 0) {
            // 通过通道0检查通道状态
lijian6's avatar
lijian6 committed
1067
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1068
1069
1070
            while(lane_id == 0) {
                // 准备复制
                if(cached_channel_head_idx != cached_channel_tail_idx)
Chenggang Zhao's avatar
Chenggang Zhao committed
1071
                    break;
lijian6's avatar
lijian6 committed
1072
1073
1074
1075
1076
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                            channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1077
1078
1079
1080
                    trap();
                }
            }

lijian6's avatar
lijian6 committed
1081
            // 同步队列尾部
lijian6's avatar
lijian6 committed
1082
1083
            cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);

lijian6's avatar
lijian6 committed
1084
            // 复制数据
Chenggang Zhao's avatar
Chenggang Zhao committed
1085
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1086
1087
1088
1089
            for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto meta               = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx  = shfl_sync(total_offset, meta.src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1090
1091
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

lijian6's avatar
lijian6 committed
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
                // 复制数据
                UNROLLED_WARP_COPY(5,
                                lane_id,
                                hidden_int4,
                                recv_x + recv_token_idx * hidden_int4,
                                nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                ld_nc_global,
                                st_na_global);

                // 复制源元数据
                if(lane_id == 0 && !kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
1103
                    st_na_global(recv_src_meta + recv_token_idx, meta);
lijian6's avatar
lijian6 committed
1104

lijian6's avatar
lijian6 committed
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
                // 复制比例
                UNROLLED_WARP_COPY(1,
                                lane_id,
                                num_scales,
                                recv_x_scales + recv_token_idx * num_scales,
                                nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                ld_nc_global,
                                st_na_global);

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
lijian6's avatar
lijian6 committed
1116
1117
                    auto recv_idx   = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
lijian6's avatar
lijian6 committed
1118
1119
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1120
1121
1122
                }
            }

lijian6's avatar
lijian6 committed
1123
            // 移动队列
lijian6's avatar
lijian6 committed
1124
            syncwarp();
lijian6's avatar
lijian6 committed
1125
            if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1126
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
lijian6's avatar
lijian6 committed
1127
1128
            }
        } // while(num_tokens_to_recv > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1129
    }
lijian6's avatar
lijian6 committed
1130
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1131
1132
    shmem_wg_ctx_destroy(&ctx);
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1133
1134
}

lijian6's avatar
lijian6 committed
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1148
    constexpr int kNumDispatchRDMASenderWarps = 7;
lijian6's avatar
lijian6 committed
1149
1150
    // Make sure never OOB
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
lijian6's avatar
lijian6 committed
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                       \
    {                                                                                              \
        auto dispatch_func =                                                                       \
            low_latency_mode                                                                       \
                ? (is_cached_dispatch                                                              \
                       ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>         \
                       : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>)       \
                : (is_cached_dispatch                                                              \
                       ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>        \
                       : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>);     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch_func, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx,   \
            recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta),                      \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head,   \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix,        \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,      \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales,       \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,       \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
Chenggang Zhao's avatar
Chenggang Zhao committed
1176
1177
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

lijian6's avatar
lijian6 committed
1178
1179
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
                        (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1180
1181
1182
1183
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lijian6's avatar
lijian6 committed
1184
template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
1185
__global__ void __launch_bounds__(1024, 1)
lijian6's avatar
lijian6 committed
1186
1187
1188
1189
1190
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
1191
              bool is_cached_dispatch, const shmem_team_t rdma_team) {
lijian6's avatar
lijian6 committed
1192
1193
    auto sm_id       = static_cast<int>(blockIdx.x);
    auto thread_id   = static_cast<int>(threadIdx.x);
Chenggang Zhao's avatar
Chenggang Zhao committed
1194
    auto num_threads = static_cast<int>(blockDim.x);
lijian6's avatar
lijian6 committed
1195
1196
1197
    auto num_warps   = num_threads / kWarpSize;
    auto warp_id     = thread_id / kWarpSize;
    auto lane_id     = get_lane_id();
Chenggang Zhao's avatar
Chenggang Zhao committed
1198

lijian6's avatar
lijian6 committed
1199
    auto nvl_rank       = rank % NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1200
1201
1202
1203
1204
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Using two SMs, which clean the RDMA/NVL buffer respectively
    if (sm_id == 0) {
        // Barrier for RDMA
lishen's avatar
lishen committed
1205
        if (thread_id == kWarpSize)
lijian6's avatar
lijian6 committed
1206
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Shangyan Zhou's avatar
Fix  
Shangyan Zhou committed
1207

lishen's avatar
lishen committed
1208
1209
        // Barrier for NVL
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1210

lishen's avatar
lishen committed
1211
        // Clean RDMA buffer
lijian6's avatar
lijian6 committed
1212
        auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
1213
1214
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1215

lishen's avatar
lishen committed
1216
        // Clean NVL buffer
lijian6's avatar
lijian6 committed
1217
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
Chenggang Zhao's avatar
Chenggang Zhao committed
1218
1219
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
lishen's avatar
lishen committed
1220

1221
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1222

lishen's avatar
lishen committed
1223
1224
1225
1226
        // Barrier again
        if (thread_id == kWarpSize)
            dushmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);

Chenggang Zhao's avatar
Chenggang Zhao committed
1227
        // Barrier again
1228
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
lishen's avatar
lishen committed
1229
    } else if (sm_id == 1) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1230
1231
1232
1233
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1234
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
1235
1236
1237
1238

        // Iterate in reverse order
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1239
1240
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1241
1242
1243

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1244
1245
1246
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                auto current_head =
                    __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1259
1260
1261
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
                                  rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
lishen's avatar
lishen committed
1262
        constexpr int num_clean_sms = 2;
1263

lijian6's avatar
lijian6 committed
1264
        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
lishen's avatar
lishen committed
1265
1266
            for (int dst_rdma_rank = sm_id - num_clean_sms; dst_rdma_rank < num_rdma_ranks;
                 dst_rdma_rank += num_channels * 2 - num_clean_sms) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1267
                // Iterate in reverse order
lijian6's avatar
lijian6 committed
1268
1269
1270
1271
1272
1273
                int token_start_idx =
                    warp_id == 0
                        ? 0
                        : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx =
                    rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
1274
1275
1276
1277
1278
                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1279
1280
1281
1282
1283
1284
1285
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                    auto current_head =
                        __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                    if (current_head < 0) {
                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
1286
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1287
1288
1289
1290
1291
1292
1293
                }
            }
        }
    }
}

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
1294
1295
1296
1297
1298
1299
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
1300
                   bool is_cached_dispatch, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1301
    const int  num_threads    = ::max(128, kWarpSize * num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1302
1303
1304
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
lijian6's avatar
lijian6 committed
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
1315
1316
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
lishen's avatar
lishen committed
1317
    EP_HOST_ASSERT(num_channels * 2 > 2);
Chenggang Zhao's avatar
Chenggang Zhao committed
1318
1319

    // Launch kernel
lijian6's avatar
lijian6 committed
1320
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
Chenggang Zhao's avatar
Chenggang Zhao committed
1321
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
lijian6's avatar
lijian6 committed
1322
1323
1324
1325
1326
1327
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1328
1329
}

lijian6's avatar
lijian6 committed
1330
1331
1332
1333
1334
// Reduces one token received from several source ranks.
//
// Lanes and inputs:
//  * Intended to be called by every lane of a warp, since it broadcasts with `shfl_sync`.
//  * Lane `i` (`i < kNumRanks`) supplies `is_token_in_rank` and `head_idx` for rank `i`.
//  * For each rank holding the token (at most `kMaxNumRanks`), the slot is
//    `head_idx % num_max_recv_tokens`.
//
// Callbacks:
//  * `recv_fn(rank, slot, int4_idx)` returns `int4` chunk `int4_idx` of that rank's hidden vector.
//  * `recv_tw_fn(rank, slot, topk_idx)` returns one top-k weight. It is only called on lanes with
//    `lane_id < num_topk`.
//
// Outputs:
//  * Hidden chunks are summed in `float`, cast back to `dtype_t`, and stored to `combined_row`.
//  * Top-k weights are summed and stored to `combined_topk_weights`.
//  * Both stores use `st_na_global`.
//
// Returns the smallest rank holding the token, or -1 if no rank holds it.
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    // `int` rather than `size_t` so the element loops below compare signed with signed
    constexpr int kDtypePerInt4 = static_cast<int>(sizeof(int4) / sizeof(dtype_t));

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        // Check capacity BEFORE writing, so an overflow traps instead of writing past the end of
        // `topk_ranks` / `slot_indices`. The first operand is a compile-time constant, so the
        // check only matters when `kNumRanks > kMaxNumRanks`.
        EP_DEVICE_ASSERT(kNumRanks <= kMaxNumRanks or num_topk_ranks < kMaxNumRanks);
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }

    // Reduce data: each lane handles every `kWarpSize`-th `int4` chunk of the hidden vector
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

        // Reduce all-to-all results in `float`
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`: lane `k` reduces weight `k`
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Ranks were collected in ascending order, so `topk_ranks[0]` is the minimum top-k rank.
    // Return -1 instead of reading an uninitialized entry when no rank holds the token.
    return num_topk_ranks > 0 ? topk_ranks[0] : -1;
}

lijian6's avatar
lijian6 committed
1390
1391
1392
1393
template <bool kLowLatencyMode,
          int kNumRDMARanks,
          typename dtype_t,
          int kNumCombineForwarderWarps,
lijian6's avatar
lijian6 committed
1394
          int kNumTopkRDMARanks     = get_num_topk_rdma_ranks(kNumRDMARanks),
lijian6's avatar
lijian6 committed
1395
          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
lijian6's avatar
lijian6 committed
1396
          int kNumForwarders        = kNumRDMARanks * kNumWarpsPerForwarder,
lijian6's avatar
lijian6 committed
1397
1398
1399
          int kNumRDMAReceivers     = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) 
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
lijian6's avatar
lijian6 committed
1400
1401
1402
1403
1404
1405
1406
1407
            const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
            const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
            const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
            const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
            int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
            int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
            int num_ranks) {
lijian6's avatar
lijian6 committed
1408
1409
1410
1411
1412
1413
1414
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kRDMACoordinator,
        kNVLCoordinator
    };
lijian6's avatar
lijian6 committed
1415
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1416
1417
1418
    __shared__ shmem_ctx_t ctx;
    shmem_wg_ctx_create(&ctx);
#endif
lijian6's avatar
lijian6 committed
1419

lijian6's avatar
lijian6 committed
1420
1421
1422
1423
1424
1425
    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

Chenggang Zhao's avatar
Chenggang Zhao committed
1426
1427
1428
1429
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
1430
1431
1432
1433
1434
1435
1436

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumForwarders) {
                return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
Chenggang Zhao's avatar
Chenggang Zhao committed
1437
            } else {
lijian6's avatar
lijian6 committed
1438
                return {WarpRole::kRDMACoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1439
1440
            }
        } else {
lijian6's avatar
lijian6 committed
1441
            if(warp_id < kNumForwarders) {
lijian6's avatar
lijian6 committed
1442
                return {WarpRole::kRDMAReceiver, warp_id};
Chenggang Zhao's avatar
Chenggang Zhao committed
1443
            } else {
lijian6's avatar
lijian6 committed
1444
                return {WarpRole::kNVLCoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1445
1446
1447
            }
        }
    }();
lijian6's avatar
lijian6 committed
1448

Chenggang Zhao's avatar
Chenggang Zhao committed
1449
    auto warp_role = role_meta.first;
lijian6's avatar
lijian6 committed
1450
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
Chenggang Zhao's avatar
Chenggang Zhao committed
1451

lijian6's avatar
lijian6 committed
1452
    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
1453
1454
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1455
    // This approach is designed to sync multiple warps in a loop
lijian6's avatar
lijian6 committed
1456
1457
1458
1459
    constexpr int num_sync_large_iteration = 64;
    constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
    __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];   
    for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
lijian6's avatar
lijian6 committed
1460
1461
1462
1463
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
1464
    if (warp_role == WarpRole::kNVLSender) {
lijian6's avatar
lijian6 committed
1465
1466
1467
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
1468

lijian6's avatar
lijian6 committed
1469
        const auto dst_nvl_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
1470
1471
1472
        // NVL layouts
        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
1473
1474
1475
1476
1477
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
1478

Chenggang Zhao's avatar
Chenggang Zhao committed
1479
1480
        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
lijian6's avatar
lijian6 committed
1481
1482
        if(lane_id < kNumRDMARanks) {
            int prefix_idx  = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
Chenggang Zhao's avatar
Chenggang Zhao committed
1483
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
lijian6's avatar
lijian6 committed
1484
            token_end_idx   = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1485
        }
lijian6's avatar
lijian6 committed
1486
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1487
1488
1489

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1490
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1491
1492

        // Iterate over all tokens and send by chunks
lijian6's avatar
lijian6 committed
1493
        while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1494
            // Exit if possible
lijian6's avatar
lijian6 committed
1495
            if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1496
1497
                break;

lijian6's avatar
lijian6 committed
1498
            // Decide next RDMA buffer to send
Chenggang Zhao's avatar
Chenggang Zhao committed
1499
            bool is_lane_ready = false;
lijian6's avatar
lijian6 committed
1500
            auto start_time    = wall_clock64();
lijian6's avatar
lijian6 committed
1501
1502

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1503
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1504
                is_lane_ready      = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
lijian6's avatar
lijian6 committed
1505
1506
1507
                                num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;

                if(__any_sync(kFullWarpMask, is_lane_ready))
Chenggang Zhao's avatar
Chenggang Zhao committed
1508
                    break;
lijian6's avatar
lijian6 committed
1509

Chenggang Zhao's avatar
Chenggang Zhao committed
1510
                // Retry
lijian6's avatar
lijian6 committed
1511
1512
                if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1513
1514

                // Timeout check
lijian6's avatar
lijian6 committed
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
                if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
                        "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        ld_volatile_global(nvl_channel_head.buffer() + lane_id),
                        cached_channel_tail_idx,
                        token_start_idx,
                        token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1527
1528
1529
1530
1531
                    trap();
                }
            }

            // Sync token start index and count
lijian6's avatar
lijian6 committed
1532
1533
            for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
                if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1534
1535
1536
                    continue;

                // Sync token start index
lijian6's avatar
lijian6 committed
1537
1538
                auto token_idx          = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1539
1540

                // Send by chunk
lijian6's avatar
lijian6 committed
1541
                for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1542
1543
                    // Get an empty slot
                    int dst_slot_idx = 0;
lijian6's avatar
lijian6 committed
1544
1545
1546
                    if(lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
1547
                    }
lijian6's avatar
lijian6 committed
1548
                    dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
lijian6's avatar
lijian6 committed
1549
1550
1551
1552

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x         = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
1553
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
1554

lijian6's avatar
lijian6 committed
1555
                    // Copy source meta
lijian6's avatar
lijian6 committed
1556
1557
                    if(lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1558

lijian6's avatar
lijian6 committed
1559
                    // Copy `topk_weights`
lijian6's avatar
lijian6 committed
1560
1561
1562
                    if(lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
                                    ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1563
1564
1565
1566
1567
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
lijian6's avatar
lijian6 committed
1568
            syncwarp();
lijian6's avatar
lijian6 committed
1569
1570
1571
            if(lane_id < kNumRDMARanks and is_lane_ready) {
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1572
1573
        }
    } else {
lijian6's avatar
lijian6 committed
1574
1575
1576
1577
        if(warp_id > kNumForwarders) {
            return;
        }

Chenggang Zhao's avatar
Chenggang Zhao committed
1578
1579
        // Combiners and coordinators
        // RDMA symmetric layout
lijian6's avatar
lijian6 committed
1580
        auto hidden_bytes = hidden_int4 * sizeof(int4);
lijian6's avatar
lijian6 committed
1581
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
lijian6's avatar
lijian6 committed
1582
1583
1584
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1585
1586

        // NVL layouts
lijian6's avatar
lijian6 committed
1587
1588
1589
1590
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
Chenggang Zhao's avatar
Chenggang Zhao committed
1591
            nvl_buffers[i] = buffer_ptrs[i];
lijian6's avatar
lijian6 committed
1592
1593
1594
1595
1596
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
Chenggang Zhao's avatar
Chenggang Zhao committed
1597
1598

        // Combiner warp synchronization
lijian6's avatar
lijian6 committed
1599
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
1600
        __shared__ volatile bool forwarder_retired[kNumForwarders];
lijian6's avatar
lijian6 committed
1601
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
1602
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
lijian6's avatar
lijian6 committed
1603

Chenggang Zhao's avatar
Chenggang Zhao committed
1604
1605
1606
        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA ranks
lijian6's avatar
lijian6 committed
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
            const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
            const auto sub_warp_id   = target_rank % kNumWarpsPerForwarder;
            auto send_buffer         = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            // auto sync_large_warp     = [=]() {
            //     if(kNumWarpsPerForwarder == 1) {
            //         syncwarp();
            //     } else {
            //         // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
            //         // __syncthreads();
            //         syncwarp();
            //     }
            // };
            auto sync_large_warp = [=](const int iter, const int mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1620
                if (kNumWarpsPerForwarder == 1) {
lijian6's avatar
lijian6 committed
1621
                    syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1622
                } else {
lijian6's avatar
lijian6 committed
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
                        // LDS index to store for sync
                        int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        //reset index in the LDS to avoid race condition due to warp scheduling
                        int reset_idx =         dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        auto start_time = wall_clock64();
                        if (lane_id == 0){
                            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
                        }
                        syncwarp();
                        //The while(...) loop polls the counter until all warps have arrived
                        if (lane_id == 0){
                            while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
                                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                                    printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
                                    trap();
                                }
lijian6's avatar
lijian6 committed
1639
1640
                            }
                        }
lijian6's avatar
lijian6 committed
1641
1642
1643
1644
1645
                        syncwarp();
                        if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
                            sync_large_warp_counters[reset_idx] = 0;
                        }
                        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1646
1647
                }
            };
lijian6's avatar
lijian6 committed
1648
            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
1649

lijian6's avatar
lijian6 committed
1650
1651
            // Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
lijian6's avatar
lijian6 committed
1652
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
lijian6's avatar
lijian6 committed
1653
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
Chenggang Zhao's avatar
Chenggang Zhao committed
1654
1655
1656
1657
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1658
1659
1660
1661
1662
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
            // sync_forwarder_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1663
1664
1665

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1666
1667
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1668
1669
1670
1671
1672
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
lijian6's avatar
lijian6 committed
1673
            for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1674
                // Check destination queue emptiness, or wait a buffer to be released
lijian6's avatar
lijian6 committed
1675
                auto token_end_idx      = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
Chenggang Zhao's avatar
Chenggang Zhao committed
1676
                auto num_chunked_tokens = token_end_idx - token_start_idx;
lijian6's avatar
lijian6 committed
1677
                auto start_time         = wall_clock64();
lijian6's avatar
lijian6 committed
1678
1679
1680
1681
1682
1683
                while(sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));

                    if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
1684
1685
1686
                        break;

                    // Timeout check
lijian6's avatar
lijian6 committed
1687
1688
1689
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                                channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
1690
1691
1692
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1693
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1694
                sync_large_warp(token_start_idx, 0);
lijian6's avatar
lijian6 committed
1695

Chenggang Zhao's avatar
Chenggang Zhao committed
1696
                // Combine and write to the RDMA buffer
lijian6's avatar
lijian6 committed
1697
                for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1698
                    // Read expected head
lijian6's avatar
lijian6 committed
1699
                    EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1700
                    int expected_head = -1;
lijian6's avatar
lijian6 committed
1701
1702
                    if(lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1703
1704

                    // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1705
                    start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1706
1707
                    while(cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1708
1709

                        // Timeout check
lijian6's avatar
lijian6 committed
1710
1711
1712
                        if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                    channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1713
1714
1715
1716
1717
                            trap();
                        }
                    }

                    // Combine current token
lijian6's avatar
lijian6 committed
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                    expected_head, lane_id,
                                                                                    hidden_int4, num_topk,
                                                                                    reinterpret_cast<int4*>(shifted),
                                                                                    reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                    num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1728
1729

                    // Update head
lijian6's avatar
lijian6 committed
1730
1731
1732
1733
                    if(lane_id < NUM_MAX_NVL_PEERS) {
                        expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
                                        : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1734
                }
lijian6's avatar
lijian6 committed
1735
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1736
                sync_large_warp(token_start_idx, 1);
lijian6's avatar
lijian6 committed
1737

Chenggang Zhao's avatar
Chenggang Zhao committed
1738
                // Issue RDMA send
lijian6's avatar
lijian6 committed
1739
1740
                if(sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1741
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
1742
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1743
1744
1745
1746
                        shmem_ctx_schar_put_nbi_warp(ctx,
#else
                        shmemx_int8_put_nbi_warp(
#endif
lijian6's avatar
lijian6 committed
1747
1748
1749
1750
1751
1752
                            rdma_channel_data.recv_buffer(rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            rdma_channel_data.send_buffer(dst_rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            num_chunked_tokens * num_bytes_per_rdma_token,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1753
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1754
1755
1756
1757
                        shmem_ctx_quiet(ctx);                
#else
                        shmem_fence();
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1758
1759
1760
1761
1762
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
lijian6's avatar
lijian6 committed
1763
                    syncwarp();
lijian6's avatar
lijian6 committed
1764
                    if(lane_id == 0) {
lijian6's avatar
lijian6 committed
1765
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1766
1767
1768
1769
1770
                        shmem_ctx_ulong_atomic_add(ctx,
#else
                        shmem_signal_op_add(
#endif
                            rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
lijian6's avatar
lijian6 committed
1771
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1772
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1773
1774
1775
1776
                }
            }

            // Retired
lijian6's avatar
lijian6 committed
1777
            syncwarp();
lijian6's avatar
lijian6 committed
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
            if(lane_id == 0) {
                forwarder_retired[target_rank] = true;
            }
        } else if (warp_role == WarpRole::kRDMACoordinator) {
            // Coordinator
            // Sync shared memory status
            // sync_forwarder_smem();
            __syncthreads();
            constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
                // Retired
                if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for(int i = 0; i < kNumRDMARanks; ++i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for(int j = 0; j < num_warps_per_rdma_rank; ++j)
                            if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);

                        if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                        }
                    }
                }

                // Nanosleep and let other warps work
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
            }
        } else if(warp_role == WarpRole::kRDMAReceiver) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1817
1818
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1819
            EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lijian6's avatar
lijian6 committed
1820
1821
1822
1823
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1824
1825
1826

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1827
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1828
1829
1830

            // Iterate over all tokens and combine
            int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1831
            for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1832
                // Read expected head
lijian6's avatar
lijian6 committed
1833
                EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1834
                int expected_head = -1;
lijian6's avatar
lijian6 committed
1835
1836
1837
1838
                if(lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
                                        : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1839
1840
1841
                }

                // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1842
                auto start_time = wall_clock64();
Chenggang Zhao's avatar
Chenggang Zhao committed
1843
                while (cached_channel_tail_idx <= expected_head) {
lijian6's avatar
lijian6 committed
1844
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
Chenggang Zhao's avatar
Chenggang Zhao committed
1845
1846

                    // Timeout check
lijian6's avatar
lijian6 committed
1847
1848
1849
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                                channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1850
1851
1852
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1853
                syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1854
1855

                // Combine current token
lijian6's avatar
lijian6 committed
1856
1857
1858
1859
1860
1861
1862
1863
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                            expected_head, lane_id,
                                                                            hidden_int4, num_topk,
                                                                            combined_x + token_idx * hidden_int4,
                                                                            combined_topk_weights + token_idx * num_topk,
                                                                            num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1864
1865
1866
            }

            // Retired
lijian6's avatar
lijian6 committed
1867
            syncwarp();
lijian6's avatar
lijian6 committed
1868
1869
1870
1871
            if(lane_id == 0) {
                rdma_receiver_retired[target_rank] = true;
            }
        } else if(warp_role == WarpRole::kNVLCoordinator) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1872
1873
            // Coordinator
            // Sync shared memory status
lijian6's avatar
lijian6 committed
1874
1875
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1876
1877
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1878
            int last_rdma_head               = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
1879
            int last_nvl_head[kNumRDMARanks] = {0};
lijian6's avatar
lijian6 committed
1880
1881
            int dst_rdma_rank                = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
lijian6's avatar
lijian6 committed
1882
1883
1884
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1885
                // Retired
lijian6's avatar
lijian6 committed
1886
                if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
Chenggang Zhao's avatar
Chenggang Zhao committed
1887
1888
1889
                    break;

                // Find minimum head for RDMA ranks
lijian6's avatar
lijian6 committed
1890
                {
Chenggang Zhao's avatar
Chenggang Zhao committed
1891
                    int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1892
1893
1894
    #pragma unroll
                    for(int i = 0; i < kNumRDMAReceivers; ++i)
                        if(not rdma_receiver_retired[i])
lijian6's avatar
lijian6 committed
1895
                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
lijian6's avatar
lijian6 committed
1896
1897

                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lijian6's avatar
lijian6 committed
1898
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1899
1900
1901
1902
1903
                        shmem_ctx_ulong_atomic_add(ctx,
#else
                        shmem_signal_op_add(
#endif
                            rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
lijian6's avatar
lijian6 committed
1904
1905
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

1906
1907
                        last_rdma_head = min_head;
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1908
1909
1910
                }

                // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1911
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1912
1913
1914
            }
        }
    }
lijian6's avatar
lijian6 committed
1915
#if !defined(FORCE_DUSHMEM_API) && !defined(ROCM_DISABLE_CTX)
1916
1917
    shmem_wg_ctx_destroy(&ctx);
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
1918
1919
}

lijian6's avatar
lijian6 committed
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1930
    constexpr int kNumCombineForwarderWarps = 8;
lijian6's avatar
lijian6 committed
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949

#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                                        \
    {                                                                                              \
        auto combine_func =                                                                        \
            low_latency_mode                                                                       \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>           \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>;         \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights,       \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights,            \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1),        \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,           \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr,                    \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

lijian6's avatar
lijian6 committed
1950
    int num_rdma_ranks           = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1951
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
lijian6's avatar
lijian6 committed
1952
1953
    int num_forwarder_warps      = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
lijian6's avatar
lijian6 committed
1954
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
1955
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
lijian6's avatar
lijian6 committed
1956
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
lijian6's avatar
lijian6 committed
1957
    EP_HOST_ASSERT(type == HIP_R_16BF);
Chenggang Zhao's avatar
Chenggang Zhao committed
1958

lijian6's avatar
lijian6 committed
1959
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1960
1961
1962
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
1963

Chenggang Zhao's avatar
Chenggang Zhao committed
1964
1965
1966
} // namespace internode

} // namespace deep_ep
lijian6's avatar
lijian6 committed
1967

lijian6's avatar
lijian6 committed
1968
1969
1970
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
lijian6's avatar
lijian6 committed
1971
1972

#endif