internode.cu 105 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
#include "buffer.cuh"
lijian6's avatar
lijian6 committed
3
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
4
5
#include "launch.cuh"
#include "utils.cuh"
lijian6's avatar
lijian6 committed
6
7
8
9
10
11

#ifndef DISABLE_ROCSHMEM

#include <rocshmem/rocshmem.hpp>

// TODO: fix unroll warnings
lijian6's avatar
lijian6 committed
12
13
14
15
16
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
Chenggang Zhao's avatar
Chenggang Zhao committed
17
18
19
20
21

namespace deep_ep {

namespace internode {

lijian6's avatar
lijian6 committed
22
// rocSHMEM team handle, declared `extern` here (no definition in this view).
// The host-side notify_dispatch wrapper passes it to the kernel as `rdma_team`.
extern rocshmem::rocshmem_team_t cpu_rdma_team;
Chenggang Zhao's avatar
Chenggang Zhao committed
23
24
25
26
27
28
29
30
31

struct SourceMeta {
    // `src_rdma_rank`: the RDMA rank passed to the packing constructor.
    // `is_token_in_nvl_rank_bits`: bit `i` mirrors `is_token_in_nvl_ranks[i]` from that constructor.
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    // Compile-time check: this struct is only valid for exactly 8 NVL peers.
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // TODO: faster encoding
    // Packs the `NUM_MAX_NVL_PEERS` boolean flags into a bitmask (flag `i` -> bit `i`).
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        int packed = 0;
#pragma unroll
        for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
            packed |= static_cast<int>(is_token_in_nvl_ranks[peer]) << peer;

        src_rdma_rank             = rdma_rank;
        is_token_in_nvl_rank_bits = packed;
    }

    // Returns whether bit `nvl_rank` of the packed mask is set.
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        const int bit = (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
        return bit != 0;
    }
};

// Returns `sizeof(SourceMeta)` as an `int`.
int get_source_meta_bytes() {
    constexpr int kNumBytes = static_cast<int>(sizeof(SourceMeta));
    return kNumBytes;
}

lijian6's avatar
lijian6 committed
49
50
51
52
53
54
55
56
// Size in bytes of one token slot in the RDMA buffers: the sum of the hidden data
// (`hidden_int4` x `int4`), one `SourceMeta`, `num_scales` floats, `num_topk_idx` ints and
// `num_topk_weights` floats, rounded up to a multiple of `sizeof(int4)`.
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    const auto payload_bytes = hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
                               num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
                               num_topk_weights * sizeof(float);
    // Round up so consecutive token slots stay `int4`-aligned.
    return static_cast<int>(ALIGN(payload_bytes, sizeof(int4)));
}

lijian6's avatar
lijian6 committed
59
60
61
// Returns `{offset, count}`, both in `int32_t` units, describing the RDMA-buffer region that the
// notify kernel zeroes before data dispatch (it clears ints `[offset, offset + count)`):
//   - offset = get_num_bytes_per_rdma_token(...) * num_rdma_recv_buffer_tokens
//              * num_rdma_ranks * 2 * num_sms / sizeof(int)
//   - count  = (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    // Form the byte product in `size_t`, as get_nvl_clean_meta already does. Multiplying only
    // `int`s lets the signed intermediate overflow for large configurations, which is undefined
    // behavior. Results for inputs that did not overflow are unchanged.
    const auto num_bytes_per_token = static_cast<size_t>(
        get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights));
    const size_t num_data_bytes =
        num_bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms;
    const int num_ints_to_clean = (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms;
    return {static_cast<int>(num_data_bytes / sizeof(int)), num_ints_to_clean};
}
lijian6's avatar
lijian6 committed
68
69
70
71
// Returns `{offset, count}`, both in `int32_t` units: the `int32_t` offset at which the NVL
// buffer region to clean starts, and the number of `int32_t`s to clean from there.
// Unlike get_num_bytes_per_rdma_token, the per-token size here is not rounded up to `int4`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
                              "Invalid size of `SourceMeta`");

    const auto bytes_per_token = hidden_int4 * sizeof(int4) + num_scales * sizeof(float) +
                                 num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) +
                                 sizeof(SourceMeta);
    const auto data_bytes =
        num_nvl_recv_buffer_tokens * bytes_per_token * num_nvl_ranks * num_sms;
    const int num_ints_to_clean = num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms;

    return {data_bytes / sizeof(int), num_ints_to_clean};
}

template <bool kLowLatencyMode>
// Maps an RDMA (node) index to the PE id used for rocSHMEM calls.
// - Normal mode: the RDMA index is used as-is.
// - Low-latency mode: returns `dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank`, i.e. the global rank
//   of the GPU with the same local index on the destination node.
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
    if (!kLowLatencyMode)
        return dst_rdma_rank;
    return dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
}

template <bool kLowLatencyMode>
// Inter-node barrier used by the notify kernel.
// The template flag and `rdma_team` are currently unused: every call goes through
// `rocshmem::rocshmem_barrier_all()`. A team-scoped variant existed but is disabled; the original
// authors noted that shmem_device_barrier_all() might be an issue because it doesn't follow the
// OpenSHMEM specification on ROCm. Disabled variant, kept for reference:
//   kLowLatencyMode
//       ? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
//       : rocshmem::rocshmem_barrier_all();
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
    static_cast<void>(rdma_team);
    rocshmem::rocshmem_barrier_all();
}

template <bool kLowLatencyMode, int kNumRDMARanks>
// Notify kernel that runs before dispatch. The host wrapper launches it with
// 1 + kNumRDMARanks blocks.
//   - Blocks 1..kNumRDMARanks: block `1 + r` computes, per channel, how many tokens go to RDMA
//     rank `r` and to each of its NVL ranks. It then turns those rows into inclusive prefix sums
//     (`rdma_channel_prefix_matrix`, `gbl_channel_prefix_matrix`).
//   - Block 0:
//     1. Exchanges per-rank, per-expert and per-node token counts with RDMA peers (rocSHMEM put)
//        and with NVL peers (through `buffer_ptrs`).
//     2. Zeroes the RDMA/NVL regions `[*_clean_offset, *_clean_offset + *_num_int_clean)`.
//     3. Publishes totals through the `moe_recv_*_mapped` counters. Each write waits until the
//        counter reads -1; presumably the host resets them — confirm against the caller.
// Block 0 requires at least two warps (asserted).
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
                int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
                const int nvl_clean_offset, const int nvl_num_int_clean,
                int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
                int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
                const rocshmem::rocshmem_team_t rdma_team) {
    auto sm_id       = static_cast<int>(blockIdx.x);
    auto tid         = static_cast<int>(threadIdx.x);
    auto warp_id     = tid / kWarpSize;
    auto lane_id     = get_lane_id();
    auto num_threads = static_cast<int>(blockDim.x);
    auto num_warps   = num_threads / kWarpSize;

    auto rdma_rank        = rank / NUM_MAX_NVL_PEERS;
    auto nvl_rank         = rank % NUM_MAX_NVL_PEERS;
    auto num_rdma_experts = num_experts / kNumRDMARanks;
    auto num_nvl_experts  = num_rdma_experts / NUM_MAX_NVL_PEERS;

    // ------------------------------------------------------------------------------------------
    // Blocks 1..kNumRDMARanks: per-channel token counts towards one destination node.
    // ------------------------------------------------------------------------------------------
    if (sm_id != 0) {
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
                                  "Invalid number of NVL peers");
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");

        const int dst_rdma_rank = sm_id - 1;

        // One warp per channel; channels are distributed round-robin over the warps.
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
                                   token_end_idx);

            int num_tokens_to_node                 = 0;
            int per_peer_count[NUM_MAX_NVL_PEERS] = {0};
            // Each lane visits every kWarpSize-th token of the channel.
            for (int64_t token = token_start_idx + lane_id; token < token_end_idx;
                 token += kWarpSize) {
                // Load the 8 per-NVL-rank flags of the destination node as one 64-bit word.
                const auto flags_u64 = *reinterpret_cast<const uint64_t *>(
                    is_token_in_rank + token * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                const auto flags = reinterpret_cast<const bool *>(&flags_u64);
#pragma unroll
                for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
                    per_peer_count[peer] += flags[peer];
                // A token counts once for the node if it goes to any of its NVL ranks.
                num_tokens_to_node += (flags_u64 != 0);
            }

            // Warp-wide reductions.
            num_tokens_to_node = warp_reduce_sum(num_tokens_to_node);
#pragma unroll
            for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
                per_peer_count[peer] = warp_reduce_sum(per_peer_count[peer]);

            // Lane 0 writes this channel's column of the count matrices.
            if (lane_id == 0) {
#pragma unroll
                for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + peer) *
                                                  num_channels +
                                              channel_id] = per_peer_count[peer];
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] =
                    num_tokens_to_node;
            }
        }

        // Turn the rows into inclusive prefix sums over channels.
        __syncthreads();
        if (tid == 0) {
            int *row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
            for (int c = 1; c < num_channels; ++c)
                row[c] += row[c - 1];
        }
        if (tid < NUM_MAX_NVL_PEERS) {
            int *row = gbl_channel_prefix_matrix +
                       (dst_rdma_rank * NUM_MAX_NVL_PEERS + tid) * num_channels;
            for (int c = 1; c < num_channels; ++c)
                row[c] += row[c - 1];
        }
        return;
    }

    // ------------------------------------------------------------------------------------------
    // Block 0: exchange token counts with RDMA and NVL peers.
    // ------------------------------------------------------------------------------------------

    // Global barrier: thread `kWarpSize` (first lane of the second warp) issues the inter-node
    // barrier, and `barrier_block` synchronizes the NVL peers.
    EP_DEVICE_ASSERT(num_warps > 1);
    EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
    if (tid == kWarpSize)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
    __syncthreads();

    // Per-node "mixed" count record:
    //   [0, NUM_MAX_NVL_PEERS)                         tokens per NVL rank of that node
    //   [NUM_MAX_NVL_PEERS, +num_rdma_experts)         tokens per expert of that node
    //   [NUM_MAX_NVL_PEERS + num_rdma_experts]         tokens for the node as a whole
    // `rdma_buffer_ptr_int` must be taken before SymBuffer is constructed from `rdma_buffer_ptr`.
    auto      rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
    const int num_mixed_ints      = NUM_MAX_NVL_PEERS + num_rdma_experts + 1;
    auto      rdma_recv_num_tokens_mixed =
        SymBuffer<int>(rdma_buffer_ptr, num_mixed_ints, kNumRDMARanks);

    // Clean up for later data dispatch (the count buffers must lie below the cleaned region).
    EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int));
    for (int idx = tid; idx < rdma_num_int_clean; idx += num_threads)
        rdma_buffer_ptr_int[rdma_clean_offset + idx] = 0;

    // Fill the per-destination-node send records.
    for (int r = tid; r < num_ranks; r += num_threads)
        rdma_recv_num_tokens_mixed.send_buffer(r / NUM_MAX_NVL_PEERS)[r % NUM_MAX_NVL_PEERS] =
            num_tokens_per_rank[r];
    for (int e = tid; e < num_experts; e += num_threads)
        rdma_recv_num_tokens_mixed.send_buffer(
            e / num_rdma_experts)[NUM_MAX_NVL_PEERS + e % num_rdma_experts] =
            num_tokens_per_expert[e];
    if (tid < kNumRDMARanks)
        rdma_recv_num_tokens_mixed.send_buffer(tid)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
            num_tokens_per_rdma_rank[tid];
    __syncthreads();

    // Issue send: thread `r` pushes the record for node `r` into slot `rdma_rank` of node `r`.
    // TODO: more light fence or barrier or signaling
    // TODO: overlap EP barrier and NVL cleaning
    if (tid < kNumRDMARanks) {
        rocshmem::rocshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                                       rdma_recv_num_tokens_mixed.send_buffer(tid), num_mixed_ints,
                                       translate_dst_rdma_rank<kLowLatencyMode>(tid, nvl_rank));
    }
    __syncthreads();
    if (tid == 0)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    __syncthreads();

    // NVL buffers: thread `p` (< NUM_MAX_NVL_PEERS) sends into peer `p`'s buffer.
    // NOTE(review): the Buffer/AsymBuffer constructors appear to advance the pointer they are
    // given (cf. `advance_also`). If so, the construction order below defines the buffer layout
    // and must not change.
    auto nvl_send_buffer = tid < NUM_MAX_NVL_PEERS ? buffer_ptrs[tid] : nullptr;
    auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
    auto nvl_reduced_num_tokens_per_expert =
        Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
    auto nvl_send_num_tokens_per_rank =
        AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
    auto nvl_send_num_tokens_per_expert =
        AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
    auto nvl_recv_num_tokens_per_rank =
        AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
    auto nvl_recv_num_tokens_per_expert =
        AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

    // Clean up for later data dispatch (offsets are relative to the unadvanced local buffer).
    auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
    EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
                                  nvl_send_num_tokens_per_rank.total_bytes +
                                  nvl_send_num_tokens_per_expert.total_bytes <=
                              nvl_clean_offset * sizeof(int));
    for (int idx = tid; idx < nvl_num_int_clean; idx += num_threads)
        nvl_buffer_ptr_int[nvl_clean_offset + idx] = 0;

    // Sum, over all source nodes, the tokens received for each expert of this node.
    // TODO: may use NVSHMEM reduction
    EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
    if (tid < num_rdma_experts) {
        int expert_total = 0;
#pragma unroll
        for (int src = 0; src < kNumRDMARanks; ++src)
            expert_total += rdma_recv_num_tokens_mixed.recv_buffer(src)[NUM_MAX_NVL_PEERS + tid];
        nvl_reduced_num_tokens_per_expert[tid] = expert_total;
    }
    __syncthreads();

    // Prefix-sum the per-node totals, then publish the grand total once the counter reads -1.
    if (tid == 0) {
        int running = 0;
#pragma unroll
        for (int src = 0; src < kNumRDMARanks; ++src) {
            running +=
                rdma_recv_num_tokens_mixed.recv_buffer(src)[NUM_MAX_NVL_PEERS + num_rdma_experts];
            recv_rdma_rank_prefix_sum[src] = running;
        }
        while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
            ;
        *moe_recv_rdma_counter_mapped = running;
    }

    // Forward per-rank and per-expert counts to the NVL peers.
    EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
    if (tid < NUM_MAX_NVL_PEERS) {
#pragma unroll
        for (int src = 0; src < kNumRDMARanks; ++src)
            nvl_send_num_tokens_per_rank.buffer(nvl_rank)[src] =
                rdma_recv_num_tokens_mixed.recv_buffer(src)[tid];
        for (int e = 0; e < num_nvl_experts; ++e)
            nvl_send_num_tokens_per_expert.buffer(nvl_rank)[e] =
                nvl_reduced_num_tokens_per_expert[tid * num_nvl_experts + e];
    }
    memory_fence();
    __syncthreads();
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
    __syncthreads();

    // Reduce number of tokens per rank/expert received from the NVL peers.
    EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
    if (tid == 0) {
        int running = 0;
        for (int src_rank = 0; src_rank < num_ranks; ++src_rank) {
            const int src_rdma_rank = src_rank / NUM_MAX_NVL_PEERS;
            const int src_nvl_rank  = src_rank % NUM_MAX_NVL_PEERS;
            running += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
            recv_gbl_rank_prefix_sum[src_rank] = running;
        }
        while (ld_volatile_global(moe_recv_counter_mapped) != -1)
            ;
        *moe_recv_counter_mapped = running;
    }
    if (tid < num_nvl_experts) {
        int expert_total = 0;
#pragma unroll
        for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
            expert_total += nvl_recv_num_tokens_per_expert.buffer(peer)[tid];
        // Round up to a multiple of `expert_alignment`.
        expert_total = (expert_total + expert_alignment - 1) / expert_alignment * expert_alignment;
        while (ld_volatile_global(moe_recv_expert_counter_mapped + tid) != -1)
            ;
        moe_recv_expert_counter_mapped[tid] = expert_total;
    }

    // Finally barrier
    __syncthreads();
    if (tid == kWarpSize)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
}

lijian6's avatar
lijian6 committed
328
329
330
331
332
333
334
335
336
337
338
// Host wrapper for the notify_dispatch kernel.
// 1. Computes the RDMA/NVL regions the kernel must zero.
// 2. Checks that those regions fit in `num_rdma_bytes` / `num_nvl_bytes`, and that both sizes
//    fit in an `int`.
// 3. Launches 1 + num_rdma_ranks blocks of 256 threads on `stream`.
// Block 0 exchanges counts; the other blocks compute per-channel prefix matrices.
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                     bool low_latency_mode) {
    constexpr int kNumThreads    = 256;
    const auto    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // `{offset, count}` (in ints) of the regions the kernel zeroes for the later dispatch.
    // `num_topk` is used both as the top-k index count and as the top-k weight count.
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);

    // The cleaned regions must fit inside the buffers, and the buffer sizes must fit in an `int`.
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());

    // Expanded by SWITCH_RDMA_RANKS once per supported compile-time RDMA rank count.
#define NOTIFY_DISPATCH_LAUNCH_CASE(kRanks)                                                        \
    {                                                                                              \
        auto notify_dispatch_func =                                                                \
            low_latency_mode ? notify_dispatch<true, kRanks> : notify_dispatch<false, kRanks>;     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks,   \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert,         \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens,             \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second,         \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix,               \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,        \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team);               \
    }                                                                                              \
    break

    // One block per destination node plus block 0 for the count exchange.
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Returns min(num_rdma_ranks, 8): at most 8 RDMA ranks are sent to.
// Used as the default of the `kNumTopkRDMARanks` template parameter of `dispatch`.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    return num_rdma_ranks >= 8 ? 8 : num_rdma_ranks;
}

lijian6's avatar
lijian6 committed
383
384
385
386

template <bool kLowLatencyMode,
          int kNumRDMARanks,
          bool kCachedMode,
lijian6's avatar
lijian6 committed
387
388
          int kNumDispatchRDMASenderWarps,
          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
lijian6's avatar
lijian6 committed
389
390
391
392
393
394
395
396
397
398
399
400
401
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
        SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
        const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
        int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
        int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
        const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
        const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
        int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
        int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
        int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
        int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
        int num_ranks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
402
    enum class WarpRole {
lijian6's avatar
lijian6 committed
403
404
405
406
407
        kRDMASender,            // 从x写入到RDMA发送缓存
        kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
        kRDMAAndNVLForwarder,   // 从RDMA接收缓存转写到ipc nvl缓存
        kForwarderCoordinator,  // 向远端RDMA确认接收
        kNVLReceivers           // 从nvl缓存写入到recv_x
Chenggang Zhao's avatar
Chenggang Zhao committed
408
409
    };

lijian6's avatar
lijian6 committed
410
411
412
413
414
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
lijian6's avatar
lijian6 committed
415
416
417
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
418
419
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

lijian6's avatar
lijian6 committed
420
421
422
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);

Chenggang Zhao's avatar
Chenggang Zhao committed
423
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
lijian6's avatar
lijian6 committed
424
425
426
427
428
429
430
431
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASender, -1};
            } else if(warp_id == kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASenderCoordinator, -1};
            }
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            if(warp_id < NUM_MAX_NVL_PEERS) {
Chenggang Zhao's avatar
Chenggang Zhao committed
432
433
434
435
436
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else {
lijian6's avatar
lijian6 committed
437
            return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
Chenggang Zhao's avatar
Chenggang Zhao committed
438
439
        }
    }();
lijian6's avatar
lijian6 committed
440

lijian6's avatar
lijian6 committed
441
    auto warp_role = role_meta.first;
Chenggang Zhao's avatar
Chenggang Zhao committed
442
443
444
    auto target_rank = role_meta.second; // Not applicable for RDMA senders

    // RDMA symmetric layout
lijian6's avatar
lijian6 committed
445
446
447
448
449
450
    auto hidden_bytes             = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
451
452

    // NVL buffer layouts
lijian6's avatar
lijian6 committed
453
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
Chenggang Zhao's avatar
Chenggang Zhao committed
454
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
lijian6's avatar
lijian6 committed
455
    int rs_wr_rank = 0, ws_rr_rank = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
456
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
lijian6's avatar
lijian6 committed
457
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
458
    if (warp_role == WarpRole::kNVLReceivers)
lijian6's avatar
lijian6 committed
459
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
460
461

    // Allocate buffers
lijian6's avatar
lijian6 committed
462
463
464
465
466
467
468
469
470
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
471
472

    // RDMA sender warp synchronization
lijian6's avatar
lijian6 committed
473
474
475
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
476

lijian6's avatar
lijian6 committed
477
478
    // NVL and RDMA coordinate Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
479
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
lijian6's avatar
lijian6 committed
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500

    // Place the main logic of your kernel here, using the parameters above.
    if(warp_role == WarpRole::kRDMASender) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
        然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
        最后,它复制相关的数据到对称发送缓冲区。

        kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
        期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
        同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
        */
        // 获取任务范围
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // 清理共享内存
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
        if(warp_id == 0 && lane_id == 0) {
            rdma_send_next_token_idx = token_start_idx;
lijian6's avatar
lijian6 committed
501
        }
lijian6's avatar
lijian6 committed
502
503
504
        if(warp_id == 0 && lane_id < kNumRDMARanks) {
            rdma_send_channel_tail[lane_id]      = 0;
            rdma_send_channel_next_tail[lane_id] = 0;
lijian6's avatar
lijian6 committed
505
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
506

lijian6's avatar
lijian6 committed
507
508
509
510
511
512
        // 发送本通道中的令牌数量,通过 `-value - 1` 表示
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
        // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
        // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
        for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
513
            if (lane_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
514
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
515
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
516
                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
517
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
518
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
519
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
lijian6's avatar
lijian6 committed
520
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
521
            }
lijian6's avatar
lijian6 committed
522

lijian6's avatar
lijian6 committed
523
524
525
            syncwarp();
            if (dst_rdma_rank != rdma_rank) {
                rocshmem::rocshmem_ctx_int_put_nbi_wave(
lijian6's avatar
lijian6 committed
526
527
528
                ctx, rdma_channel_meta.recv_buffer(rdma_rank),
                rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
529
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
530
        }
lijian6's avatar
lijian6 committed
531
        rocshmem::rocshmem_ctx_quiet(ctx);
lijian6's avatar
lijian6 committed
532
533
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
534

lijian6's avatar
lijian6 committed
535
        // 遍历令牌并复制到缓冲区
Chenggang Zhao's avatar
Chenggang Zhao committed
536
        int64_t token_idx;
lijian6's avatar
lijian6 committed
537
538
539
540
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // 读取RDMA秩的存在性
Chenggang Zhao's avatar
Chenggang Zhao committed
541
            uint64_t is_token_in_rank_uint64 = 0;
lijian6's avatar
lijian6 committed
542
543
544
            if(lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
545

lijian6's avatar
lijian6 committed
546
547
548
549
            // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
            while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
                // 等待
            }
lijian6's avatar
lijian6 committed
550
            syncwarp();
551

lijian6's avatar
lijian6 committed
552
            // 获取下一个尾部位置
lijian6's avatar
lijian6 committed
553
            int rdma_tail_idx = -1;
lijian6's avatar
lijian6 committed
554
            if(is_token_in_rank_uint64 != 0) {
lijian6's avatar
lijian6 committed
555
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
lijian6's avatar
lijian6 committed
556
557
558
559
560

                // 与kForwarderCoordinator相互配合,调节发送数据的频率
                while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
561
            }
lijian6's avatar
lijian6 committed
562
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
563

lijian6's avatar
lijian6 committed
564
565
            // 存储RDMA头部以供合并
            if(lane_id < kNumRDMARanks && !kCachedMode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
566
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
lijian6's avatar
lijian6 committed
567
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
568

lijian6's avatar
lijian6 committed
569
570
571
572
            // 更新最后一个令牌尾部
            if(last_rdma_tail_idx >= 0) {
                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            }
lijian6's avatar
lijian6 committed
573
574
            last_rdma_tail_idx = rdma_tail_idx;

lijian6's avatar
lijian6 committed
575
576
577
578
            // 释放顺序锁
            if(lane_id == 0) {
                rdma_send_next_token_idx += 1;
            }
lijian6's avatar
lijian6 committed
579

lijian6's avatar
lijian6 committed
580
            // 广播尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
581
            SourceMeta src_meta;
lijian6's avatar
lijian6 committed
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            /*
            该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
            */
            #pragma unroll
            for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
                // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
                if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
                    // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
                    //  计算slot_idx在接收缓冲区中的位置
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;

                    // 存储当前RDMA秩到topk_ranks数组中
                    topk_ranks[num_topk_ranks] = i;

                    // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
lijian6's avatar
lijian6 committed
599
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
lijian6's avatar
lijian6 committed
600
601
602
603
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);

                    // 如果当前lane_id等于num_topk_ranks,则更新src_meta
                    if(lane_id == num_topk_ranks) {
lijian6's avatar
lijian6 committed
604
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
lijian6's avatar
lijian6 committed
605
606
607
608
609
                    }

                    // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
                    // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
lijian6's avatar
lijian6 committed
610
                }
lijian6's avatar
lijian6 committed
611
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
612
613
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

lijian6's avatar
lijian6 committed
614
615
616
617
618
619
            // 复制 `x` 到对称发送缓冲区
            auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
620
            };
lijian6's avatar
lijian6 committed
621
622
623
624
625
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
            }
lijian6's avatar
lijian6 committed
626

lijian6's avatar
lijian6 committed
627
628
629
630
631
632
633
634
            // 复制源元数据到对称发送缓冲区
            if(lane_id < num_topk_ranks) {
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
            }
lijian6's avatar
lijian6 committed
635

lijian6's avatar
lijian6 committed
636
637
638
            // 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_scales; i += kWarpSize) {
lijian6's avatar
lijian6 committed
639
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
lijian6's avatar
lijian6 committed
640
641
642
643
644
645
646
647
648
649
650

                // auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
                // auto value = ld_nc_global(x_scales + offset);
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
                }
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
lijian6's avatar
lijian6 committed
651
            }
652

lijian6's avatar
lijian6 committed
653
654
655
            // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
Chenggang Zhao's avatar
Chenggang Zhao committed
656
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
lijian6's avatar
lijian6 committed
657
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
658
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
lijian6's avatar
lijian6 committed
659
660
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
661
            }
lijian6's avatar
lijian6 committed
662
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
663

lijian6's avatar
lijian6 committed
664
665
666
667
668
669
        // 结尾部分
        // 获取顺序锁
        while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
            // 等待
        }

lijian6's avatar
lijian6 committed
670
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
671

lijian6's avatar
lijian6 committed
672
673
674
675
        // 更新最后一个令牌尾部
        if(last_rdma_tail_idx >= 0) {
            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
676

lijian6's avatar
lijian6 committed
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
        // 释放顺序锁
        if(lane_id == 0) {
            rdma_send_next_token_idx += 1;
        }
    } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
        如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
        最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。

        kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
        nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
        */
        if(warp_id > kNumDispatchRDMASenderWarps) {
            return;
        }
        // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
696

lijian6's avatar
lijian6 committed
697
698
699
        // 同步共享内存,确保所有线程在继续之前都达到了这一点
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
700

lijian6's avatar
lijian6 committed
701
        // 计算当前通道需要发送的令牌数
Chenggang Zhao's avatar
Chenggang Zhao committed
702
        int num_tokens_to_send = 0;
lijian6's avatar
lijian6 committed
703
        if(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
704
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
lijian6's avatar
lijian6 committed
705
706
            if(channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
707
708
        }

lijian6's avatar
lijian6 committed
709
        // 记录上次发出的尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
710
        int last_issued_tail = 0;
lijian6's avatar
lijian6 committed
711
712
713
714
715
716
717
        // 当有任何RDMA秩需要发送令牌时,继续循环
        while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
            for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // 计算目标RDMA秩
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;

                // 获取同步后的需要发送的令牌数
lijian6's avatar
lijian6 committed
718
                synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
719

lijian6's avatar
lijian6 committed
720
721
722
723
                if(synced_num_tokens_to_send == 0)
                    continue; // 如果没有令牌需要发送,则跳过

                // 读取进度
lijian6's avatar
lijian6 committed
724
                auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
lijian6's avatar
lijian6 committed
725
726
727
728
729
                auto processed_tail          = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed    = processed_tail - synced_last_issued_tail;

                // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
                if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
730
731
                    continue;

lijian6's avatar
lijian6 committed
732
733
734
735
736
737
                // 计算本次需要发出的令牌数
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);

                // 发出RDMA发送请求
                if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
738
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
739
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
740
741
742
743
744
745
746
747
748
                    rocshmem::rocshmem_ctx_schar_put_nbi_wave(
                        ctx,
                        rdma_channel_data.recv_buffer(rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        rdma_channel_data.send_buffer(dst_rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        num_bytes_per_rdma_token * num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                    rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
749
                } else {
lijian6's avatar
lijian6 committed
750
                    // 对于本地RDMA秩,使用较轻的内存屏障
Chenggang Zhao's avatar
Chenggang Zhao committed
751
752
753
                    memory_fence();
                }

lijian6's avatar
lijian6 committed
754
                // 更新尾部位置
lijian6's avatar
lijian6 committed
755
                syncwarp();
lijian6's avatar
lijian6 committed
756
                if(lane_id == dst_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
757
758
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
lijian6's avatar
lijian6 committed
759
                    // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
lijian6's avatar
lijian6 committed
760
761
762
                    rocshmem::rocshmem_ctx_ulong_atomic_add(
                        ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
763
764
                }
            }
lijian6's avatar
lijian6 committed
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
        } // while(__any(num_tokens_to_send > 0))
    } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
        它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
        接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
        然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
        最后,它同步头部和尾部索引,并标记通道为退役状态。
        */
        // RDMA消费者和NVL生产者
        const auto dst_nvl_rank          = target_rank;                                       // 目标NVL秩
        const auto dst_rank              = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;      // 目标秩
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);              // 目标秩专家开始
        const auto dst_rank_expert_end   = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束

        // 等待计数器到达
Chenggang Zhao's avatar
Chenggang Zhao committed
781
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
lijian6's avatar
lijian6 committed
782
783
        EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
784
785
786
787
788
789
790
791
792
        if(lane_id < kNumRDMARanks) {
            while(true) {
                // 对应于kRDMASender中的数据写入
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);                     // 是nvl节点的起始地址
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);            // 本rdma节点的起始地址
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);        // 本节点的结束地址
                if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
                    // 通知NVL秩
Chenggang Zhao's avatar
Chenggang Zhao committed
793
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
lijian6's avatar
lijian6 committed
794
795
796
                    EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);

                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
797
798
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

lijian6's avatar
lijian6 committed
799
800
                    // 保存从RDMA通道接收的令牌计数
                    src_rdma_channel_prefix = -meta_2 - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
801
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
lijian6's avatar
lijian6 committed
802
803
804
805
806
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
                    if(!kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;

                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
Chenggang Zhao's avatar
Chenggang Zhao committed
807
808
809
810
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

lijian6's avatar
lijian6 committed
811
812
813
814
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
Chenggang Zhao's avatar
Chenggang Zhao committed
815
816
817
818
                    trap();
                }
            }
        }
lijian6's avatar
lijian6 committed
819
        syncwarp();
lijian6's avatar
lijian6 committed
820
821

        // 移动缓存的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
822
823
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

lijian6's avatar
lijian6 committed
824
825
826
        // 等待共享内存被清理
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
827

lijian6's avatar
lijian6 committed
828
829
830
831
        // 开始准备处理接受数据,直到所有的数据接受完成。
        // 转发从RDMA缓冲区的令牌
        // 注意:总是从本地秩开始
        int src_rdma_rank = sm_id % kNumRDMARanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
832
833
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
lijian6's avatar
lijian6 committed
834
835
        while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
            // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
lijian6's avatar
lijian6 committed
836
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
837
838
839

            // 用于给kNVLReceivers进行互动,控制数据的传输速度
            while(lane_id == 0) {
lijian6's avatar
lijian6 committed
840
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
lijian6's avatar
lijian6 committed
841
                if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
842
                    break;
lijian6's avatar
lijian6 committed
843
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
844

lijian6's avatar
lijian6 committed
845
846
847
848
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
849
850
851
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
852
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
853

lijian6's avatar
lijian6 committed
854
            // 找到下一个源RDMA秩(轮询)
lijian6's avatar
lijian6 committed
855
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
856
            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
857
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
lijian6's avatar
lijian6 committed
858
859
860
861
862
                if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));

                    if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
Chenggang Zhao's avatar
Chenggang Zhao committed
863
                        break;
lijian6's avatar
lijian6 committed
864
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
865
866
                }

lijian6's avatar
lijian6 committed
867
868
869
870
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
Chenggang Zhao's avatar
Chenggang Zhao committed
871
872
873
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
874

lijian6's avatar
lijian6 committed
875
876
            auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
877

lijian6's avatar
lijian6 committed
878
879
880
881
882
883
884
885
886
887
            // 遍历RDMA缓冲区中的每一个令牌
            for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
                void* shifted           = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
                auto src_meta           = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
                if(lane_id == src_rdma_rank) {
                    num_tokens_to_recv_from_rdma -= 1;
                }

Chenggang Zhao's avatar
Chenggang Zhao committed
888
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
lijian6's avatar
lijian6 committed
889
                if(lane_id == src_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
890
891
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
lijian6's avatar
lijian6 committed
892
                    if(!kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
893
894
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
lijian6's avatar
lijian6 committed
895
896

                if(!is_in_dst_nvl_rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
897
898
                    continue;

lijian6's avatar
lijian6 committed
899
                // 获取一个空闲槽位
lijian6's avatar
lijian6 committed
900
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
901

lijian6's avatar
lijian6 committed
902
                // 复制数据
lijian6's avatar
lijian6 committed
903
904
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
lijian6's avatar
lijian6 committed
905
906
907
                                   reinterpret_cast<int4*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
lijian6's avatar
lijian6 committed
908

lijian6's avatar
lijian6 committed
909
910
                // 复制源元数据
                if(lane_id == 0)
lijian6's avatar
lijian6 committed
911
                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
lijian6's avatar
lijian6 committed
912
                shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
lijian6's avatar
lijian6 committed
913

lijian6's avatar
lijian6 committed
914
                // 复制 `x_scales`
lijian6's avatar
lijian6 committed
915
916
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
lijian6's avatar
lijian6 committed
917
918
919
920
921
922
923
924
925
926
927
928
929
930
                                   reinterpret_cast<float*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<float*>(shifted) + num_scales;

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
                    // 读取
                    auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
                    shifted = reinterpret_cast<int*>(shifted) + num_topk;
                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);

                    // 转换和写入
                    idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
lijian6's avatar
lijian6 committed
931
                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
lijian6's avatar
lijian6 committed
932
                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
933
934
                }

lijian6's avatar
lijian6 committed
935
936
                // 在NVL缓冲区不足的情况下,提前停止
                if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
937
938
939
                    src_rdma_tail = i + 1;
            }

lijian6's avatar
lijian6 committed
940
941
942
            // 同步头部索引
            if(lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
943

lijian6's avatar
lijian6 committed
944
            // 移动尾部索引,与kNVLReceivers互相通信使用
lijian6's avatar
lijian6 committed
945
            syncwarp();
lijian6's avatar
lijian6 committed
946
947
948
            if(lane_id == 0) {
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
949
950
951
        }

        // Retired
lijian6's avatar
lijian6 committed
952
        syncwarp();
lijian6's avatar
lijian6 committed
953
        if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
954
            forward_channel_retired[dst_nvl_rank] = true;
lijian6's avatar
lijian6 committed
955
956
957
958
959
960
961
962
963
        }
    } else if(warp_role == WarpRole::kForwarderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
        它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
        然后,它清理共享内存,并初始化转发通道的头部和退役状态。
        接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
        否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
        */
Chenggang Zhao's avatar
Chenggang Zhao committed
964
        // Extra warps for forwarder coordinator should exit directly
lijian6's avatar
lijian6 committed
965
        if (warp_id > NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
966
967
            return;

lijian6's avatar
lijian6 committed
968
969
970
971
972
973
        // 转发warp协调器
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        // 清理共享内存
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
        for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
Chenggang Zhao's avatar
Chenggang Zhao committed
974
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
lijian6's avatar
lijian6 committed
975
        if(lane_id < NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
976
            forward_channel_retired[lane_id] = false;
lijian6's avatar
lijian6 committed
977
978
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
979
980

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
lijian6's avatar
lijian6 committed
981
982
983

        while(true) {
            // 找到最小的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
984
            int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
985
#pragma unroll
lijian6's avatar
lijian6 committed
986
987
            for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if(!forward_channel_retired[i])
lijian6's avatar
lijian6 committed
988
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
lijian6's avatar
lijian6 committed
989
990

            if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
Chenggang Zhao's avatar
Chenggang Zhao committed
991
                break;
lijian6's avatar
lijian6 committed
992
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
993

lijian6's avatar
lijian6 committed
994
995
            // 更新远程头部
            if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
lijian6's avatar
lijian6 committed
996
997
998
                rocshmem::rocshmem_ctx_ulong_atomic_add(
                    ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
999
1000
                last_head = min_head;
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1001

lijian6's avatar
lijian6 committed
1002
            // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1003
            __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1004
        }
lijian6's avatar
lijian6 committed
1005
1006
1007
1008
1009
1010
1011
1012
    } else if(warp_role == WarpRole::kNVLReceivers) {
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }

        // Place the main logic of your kernel here, using the parameters above.
        // NVL消费者
        // 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
Chenggang Zhao's avatar
Chenggang Zhao committed
1013
        int src_nvl_rank = target_rank, total_offset = 0;
lijian6's avatar
lijian6 committed
1014
1015
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1016
1017
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

lijian6's avatar
lijian6 committed
1018
1019
        // 接收通道偏移
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
1020
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1021
1022

        while(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1023
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1024
            end_offset   = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1025
            if(start_offset < 0 && end_offset < 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1026
1027
1028
1029
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
lijian6's avatar
lijian6 committed
1030
1031
1032
1033
            // 超时检查
            if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
1034
1035
1036
                trap();
            }
        }
lijian6's avatar
lijian6 committed
1037

Chenggang Zhao's avatar
Chenggang Zhao committed
1038
1039
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

lijian6's avatar
lijian6 committed
1040
1041
1042
        // 保存以供合并使用
        if(lane_id < kNumRDMARanks && !kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
lijian6's avatar
lijian6 committed
1043
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1044
1045

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1046
1047
        while(num_tokens_to_recv > 0) {
            // 通过通道0检查通道状态
lijian6's avatar
lijian6 committed
1048
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1049
1050
1051
            while(lane_id == 0) {
                // 准备复制
                if(cached_channel_head_idx != cached_channel_tail_idx)
Chenggang Zhao's avatar
Chenggang Zhao committed
1052
                    break;
lijian6's avatar
lijian6 committed
1053
1054
1055
1056
1057
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                            channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1058
1059
1060
1061
                    trap();
                }
            }

lijian6's avatar
lijian6 committed
1062
            // 同步队列尾部
lijian6's avatar
lijian6 committed
1063
1064
            cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);

lijian6's avatar
lijian6 committed
1065
            // 复制数据
Chenggang Zhao's avatar
Chenggang Zhao committed
1066
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1067
1068
1069
1070
            for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto meta               = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx  = shfl_sync(total_offset, meta.src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1071
1072
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

lijian6's avatar
lijian6 committed
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
                // 复制数据
                UNROLLED_WARP_COPY(5,
                                lane_id,
                                hidden_int4,
                                recv_x + recv_token_idx * hidden_int4,
                                nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                ld_nc_global,
                                st_na_global);

                // 复制源元数据
                if(lane_id == 0 && !kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
1084
                    st_na_global(recv_src_meta + recv_token_idx, meta);
lijian6's avatar
lijian6 committed
1085

lijian6's avatar
lijian6 committed
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
                // 复制比例
                UNROLLED_WARP_COPY(1,
                                lane_id,
                                num_scales,
                                recv_x_scales + recv_token_idx * num_scales,
                                nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                ld_nc_global,
                                st_na_global);

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
lijian6's avatar
lijian6 committed
1097
1098
                    auto recv_idx   = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
lijian6's avatar
lijian6 committed
1099
1100
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1101
1102
1103
                }
            }

lijian6's avatar
lijian6 committed
1104
            // 移动队列
lijian6's avatar
lijian6 committed
1105
            syncwarp();
lijian6's avatar
lijian6 committed
1106
            if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1107
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
lijian6's avatar
lijian6 committed
1108
1109
            }
        } // while(num_tokens_to_recv > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1110
    }
lijian6's avatar
lijian6 committed
1111
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1112
1113
}

lijian6's avatar
lijian6 committed
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
// Host-side launcher for the internode dispatch kernel.
//
// Selects the `dispatch<low_latency_mode, num_rdma_ranks, is_cached_dispatch,
// kNumDispatchRDMASenderWarps>` kernel instantiation that matches the runtime flags and the
// RDMA rank count, then launches it non-cooperatively on `stream` with
// `num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL` blocks of
// `(1 + NUM_MAX_NVL_PEERS) * kWarpSize` threads. All pointer/size arguments are forwarded to
// the kernel unchanged (`recv_x`, `recv_src_meta` and `x` are reinterpreted as `int4` /
// `SourceMeta` pointers).
//
// Host-side preconditions (EP_HOST_ASSERT):
//   * `num_scales * scale_hidden_stride`, computed in 64-bit, is below INT_MAX;
//   * `topk_idx` and `topk_weights` are either both null or both non-null;
//   * `recv_topk_idx` and `recv_topk_weights` are either both null or both non-null.
// NOTE(review): which RDMA rank counts are supported is decided by SWITCH_RDMA_RANKS, whose
// definition is outside this chunk.
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode) {
    // Number of RDMA-sender warps baked into every kernel instantiation selected below.
    constexpr int kNumDispatchRDMASenderWarps = 7;

    // Scale offsets are indexed with `int` inside the kernel: make sure they can never overflow.
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());

    // Top-k indices and weights must be supplied together, on both the send and receive side.
    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    // Launch geometry: a fixed number of blocks per channel, and one warp plus one warp per
    // NVL peer in every block.
    const int dispatch_num_blocks  = num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
    const int dispatch_num_threads = (1 + NUM_MAX_NVL_PEERS) * kWarpSize;
    SETUP_LAUNCH_CONFIG(dispatch_num_blocks, dispatch_num_threads, stream);

    // One expansion per compile-time RDMA rank count: pick the kernel instantiation from the two
    // runtime flags and launch it. The trailing `break` presumably terminates a `switch` case
    // generated by SWITCH_RDMA_RANKS.
#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                       \
    {                                                                                              \
        auto dispatch_kernel =                                                                     \
            is_cached_dispatch                                                                     \
                ? (low_latency_mode                                                                \
                       ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>         \
                       : dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>)       \
                : (low_latency_mode                                                                \
                       ? dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>        \
                       : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>);     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch_kernel, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx, \
            recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta),                      \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head,   \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix,        \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,      \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales,       \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,       \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lijian6's avatar
lijian6 committed
1163
// Rewrites the negative entries of one column of a row-major `int` table, walking the rows
// [row_begin, row_end) from last to first. `row_stride` is the number of ints per row and
// `column` selects the entry inside each row. A non-negative entry is remembered as
// `last_head`; a negative entry is overwritten with `-last_head - 1`, i.e. it encodes the
// nearest non-negative entry found in a later row of the range. Before any such entry is seen,
// `last_head` is `1 << 25` (a heuristic large number).
__device__ __forceinline__ void cached_notify_backfill_heads(int *heads, int row_stride, int column,
                                                             int row_begin, int row_end) {
    // NOTES: `1 << 25` is a heuristic large number
    int last_head = 1 << 25;
    for (int row = row_end - 1; row >= row_begin; --row) {
        int *slot              = heads + row * row_stride + column;
        const int current_head = __ldg(slot);
        if (current_head >= 0)
            last_head = current_head;
        else
            *slot = -last_head - 1;
    }
}

// Multi-purpose notify kernel; each block's role is chosen by its block index:
//   * block 0  — RDMA-team barrier (thread 0), zero `rdma_num_int_clean` ints of
//                `rdma_buffer_ptr` starting at `rdma_clean_offset`, rocSHMEM fence, barrier again;
//   * block 1  — NVL barrier over NUM_MAX_NVL_PEERS ranks, zero `nvl_num_int_clean` ints of this
//                rank's NVL buffer (`buffer_ptrs[nvl_rank]`) starting at `nvl_clean_offset`,
//                memory fence, barrier again;
//   * block 2  — (skipped when `is_cached_dispatch`) warp `w < num_channels` takes the token
//                range given by `get_channel_task_range` for channel `w`, and lane `r` back-fills
//                the negative entries of column `r` of `combined_rdma_head`
//                (`num_rdma_ranks` ints per token);
//   * blocks 3+ — (skipped when `is_cached_dispatch`) for destination RDMA ranks
//                `sm_id - 3, sm_id - 3 + (num_channels * 2 - 3), ...`, warp `w < num_channels`
//                derives its token range from `rdma_channel_prefix_matrix` shifted by
//                `rdma_rank_prefix_sum`, and lane `p < NUM_MAX_NVL_PEERS` back-fills column `p` of
//                `combined_nvl_head` (NUM_MAX_NVL_PEERS ints per token).
// NOTE(review): the stride `num_channels * 2 - 3` suggests a launch grid of `num_channels * 2`
// blocks; the launch site is outside this chunk — confirm there.
template <bool kLowLatencyMode>
__global__ void __launch_bounds__(1024, 1)
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
              bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) {
    const int sm_id       = static_cast<int>(blockIdx.x);
    const int thread_id   = static_cast<int>(threadIdx.x);
    const int num_threads = static_cast<int>(blockDim.x);
    const int num_warps   = num_threads / kWarpSize;
    int warp_id           = thread_id / kWarpSize;
    const int lane_id     = get_lane_id();

    int nvl_rank             = rank % NUM_MAX_NVL_PEERS;
    const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    switch (sm_id) {
    case 0: {
        // Clean the RDMA buffer region, bracketed by RDMA-team barriers.
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        __syncthreads();

        int *rdma_words = reinterpret_cast<int *>(rdma_buffer_ptr);
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_words[rdma_clean_offset + i] = 0;
        rocshmem::rocshmem_fence();
        __syncthreads();

        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
        break;
    }
    case 1: {
        // Clean this rank's NVL buffer region, bracketed by NVL barriers.
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
        __syncthreads();

        int *nvl_words = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_words[nvl_clean_offset + i] = 0;
        memory_fence();
        __syncthreads();

        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
        break;
    }
    case 2: {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);

        // One warp per channel, one lane per RDMA rank; rows are visited in reverse order.
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
                                   token_end_idx);
            cached_notify_backfill_heads(combined_rdma_head, num_rdma_ranks, lane_id,
                                         token_start_idx, token_end_idx);
        }
        break;
    }
    default: {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");

        // Only one lane per NVL peer and one warp per channel take part; `break` leaves the switch.
        if (lane_id >= NUM_MAX_NVL_PEERS or warp_id >= num_channels)
            break;

        const int rank_stride = num_channels * 2 - 3;
        for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks;
             dst_rdma_rank += rank_stride) {
            // Per-channel token prefix of this destination RDMA rank, offset by the tokens of all
            // preceding RDMA ranks.
            const int *channel_prefix = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
            const int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            const int token_start_idx = shift + (warp_id == 0 ? 0 : channel_prefix[warp_id - 1]);
            const int token_end_idx   = shift + channel_prefix[warp_id];

            // Rows are visited in reverse order.
            cached_notify_backfill_heads(combined_nvl_head, NUM_MAX_NVL_PEERS, lane_id,
                                         token_start_idx, token_end_idx);
        }
        break;
    }
    }
}

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
1278
1279
1280
1281
1282
1283
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
1284
                   bool is_cached_dispatch, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1285
    const int  num_threads    = ::max(128, kWarpSize * num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1286
1287
1288
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
lijian6's avatar
lijian6 committed
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
1299
1300
1301
1302
1303
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_channels * 2 > 3);

    // Launch kernel
lijian6's avatar
lijian6 committed
1304
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
Chenggang Zhao's avatar
Chenggang Zhao committed
1305
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
lijian6's avatar
lijian6 committed
1306
1307
1308
1309
1310
1311
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1312
1313
}

lijian6's avatar
lijian6 committed
1314
1315
1316
1317
1318
// Warp-collective reduction of one token's copies received from up to `kNumRanks` source ranks.
//
// Lane `i` (for `i < kNumRanks`) supplies `is_token_in_rank` / `head_idx` for source rank `i`.
// Both values are broadcast across the warp with `shfl_sync`, so the whole warp is expected to
// call this together. For every contributing rank, slot `head_idx % num_max_recv_tokens` is read:
//   - hidden data through `recv_fn(rank, slot, int4_idx)`; lanes stride over `[0, hidden_int4)`
//     by `kWarpSize`, accumulate in `float`, cast back to `dtype_t`, and store into `combined_row`;
//   - top-k weights through `recv_tw_fn(rank, slot, topk_idx)`; lane `k < num_topk` sums
//     entry `k` and stores it into `combined_topk_weights[k]`.
// All stores use `st_na_global`.
//
// Template parameters:
//   kNumRanks    - number of source ranks probed (lanes `0 .. kNumRanks - 1`)
//   dtype_t      - element type packed inside each `int4`
//   kMaxNumRanks - capacity of the per-token contributor list (must be <= kWarpSize); a device
//                  assert fires if more ranks contribute
// Returns the smallest contributing rank index, or -1 if no rank holds the token.
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        // NOTES: nothing forces `kNumRanks <= kMaxNumRanks`, so only store while there is room;
        // otherwise the arrays would be overrun before the assert below could report it.
        // `num_topk_ranks` is built from broadcast values only, so this branch is warp-uniform
        // and the `shfl_sync` inside it is reached by every lane.
        if (num_topk_ranks < kMaxNumRanks) {
            slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
            topk_ranks[num_topk_ranks] = i;
        }
        ++ num_topk_ranks;
    }
    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);

    // Reduce data
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

        // Reduce all-to-all results (accumulate in `float` regardless of `dtype_t`)
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank; `topk_ranks[0]` is never written when no rank holds the token
    return num_topk_ranks > 0 ? topk_ranks[0] : -1;
}

lijian6's avatar
lijian6 committed
1374
1375
1376
1377
template <bool kLowLatencyMode,
          int kNumRDMARanks,
          typename dtype_t,
          int kNumCombineForwarderWarps,
lijian6's avatar
lijian6 committed
1378
          int kNumTopkRDMARanks     = get_num_topk_rdma_ranks(kNumRDMARanks),
lijian6's avatar
lijian6 committed
1379
          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
lijian6's avatar
lijian6 committed
1380
          int kNumForwarders        = kNumRDMARanks * kNumWarpsPerForwarder,
lijian6's avatar
lijian6 committed
1381
1382
1383
          int kNumRDMAReceivers     = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) 
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
lijian6's avatar
lijian6 committed
1384
1385
1386
1387
1388
1389
1390
1391
            const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
            const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
            const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
            const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
            int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
            int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
            int num_ranks) {
lijian6's avatar
lijian6 committed
1392
1393
1394
1395
1396
1397
1398
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kRDMACoordinator,
        kNVLCoordinator
    };
Chenggang Zhao's avatar
Chenggang Zhao committed
1399

lijian6's avatar
lijian6 committed
1400
1401
1402
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

lijian6's avatar
lijian6 committed
1403
1404
1405
1406
1407
1408
    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

Chenggang Zhao's avatar
Chenggang Zhao committed
1409
1410
1411
1412
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
1413
1414
1415
1416
1417
1418
1419

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumForwarders) {
                return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
Chenggang Zhao's avatar
Chenggang Zhao committed
1420
            } else {
lijian6's avatar
lijian6 committed
1421
                return {WarpRole::kRDMACoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1422
1423
            }
        } else {
lijian6's avatar
lijian6 committed
1424
            if(warp_id < kNumForwarders) {
lijian6's avatar
lijian6 committed
1425
                return {WarpRole::kRDMAReceiver, warp_id};
Chenggang Zhao's avatar
Chenggang Zhao committed
1426
            } else {
lijian6's avatar
lijian6 committed
1427
                return {WarpRole::kNVLCoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1428
1429
1430
            }
        }
    }();
lijian6's avatar
lijian6 committed
1431

Chenggang Zhao's avatar
Chenggang Zhao committed
1432
    auto warp_role = role_meta.first;
lijian6's avatar
lijian6 committed
1433
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
Chenggang Zhao's avatar
Chenggang Zhao committed
1434

lijian6's avatar
lijian6 committed
1435
    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
1436
1437
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1438
    // This approach is designed to sync multiple warps in a loop
lijian6's avatar
lijian6 committed
1439
1440
1441
1442
    constexpr int num_sync_large_iteration = 64;
    constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
    __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];   
    for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
lijian6's avatar
lijian6 committed
1443
1444
1445
1446
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
1447
    if (warp_role == WarpRole::kNVLSender) {
lijian6's avatar
lijian6 committed
1448
1449
1450
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
1451

lijian6's avatar
lijian6 committed
1452
        const auto dst_nvl_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
1453
1454
1455
        // NVL layouts
        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
1456
1457
1458
1459
1460
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
1461

Chenggang Zhao's avatar
Chenggang Zhao committed
1462
1463
        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
lijian6's avatar
lijian6 committed
1464
1465
        if(lane_id < kNumRDMARanks) {
            int prefix_idx  = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
Chenggang Zhao's avatar
Chenggang Zhao committed
1466
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
lijian6's avatar
lijian6 committed
1467
            token_end_idx   = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1468
        }
lijian6's avatar
lijian6 committed
1469
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1470
1471
1472

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1473
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1474
1475

        // Iterate over all tokens and send by chunks
lijian6's avatar
lijian6 committed
1476
        while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1477
            // Exit if possible
lijian6's avatar
lijian6 committed
1478
            if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1479
1480
                break;

lijian6's avatar
lijian6 committed
1481
            // Decide next RDMA buffer to send
Chenggang Zhao's avatar
Chenggang Zhao committed
1482
            bool is_lane_ready = false;
lijian6's avatar
lijian6 committed
1483
            auto start_time    = wall_clock64();
lijian6's avatar
lijian6 committed
1484
1485

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1486
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1487
                is_lane_ready      = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
lijian6's avatar
lijian6 committed
1488
1489
1490
                                num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;

                if(__any_sync(kFullWarpMask, is_lane_ready))
Chenggang Zhao's avatar
Chenggang Zhao committed
1491
                    break;
lijian6's avatar
lijian6 committed
1492

Chenggang Zhao's avatar
Chenggang Zhao committed
1493
                // Retry
lijian6's avatar
lijian6 committed
1494
1495
                if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1496
1497

                // Timeout check
lijian6's avatar
lijian6 committed
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
                if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
                        "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        ld_volatile_global(nvl_channel_head.buffer() + lane_id),
                        cached_channel_tail_idx,
                        token_start_idx,
                        token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1510
1511
1512
1513
1514
                    trap();
                }
            }

            // Sync token start index and count
lijian6's avatar
lijian6 committed
1515
1516
            for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
                if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1517
1518
1519
                    continue;

                // Sync token start index
lijian6's avatar
lijian6 committed
1520
1521
                auto token_idx          = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1522
1523

                // Send by chunk
lijian6's avatar
lijian6 committed
1524
                for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1525
1526
                    // Get an empty slot
                    int dst_slot_idx = 0;
lijian6's avatar
lijian6 committed
1527
1528
1529
                    if(lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
1530
                    }
lijian6's avatar
lijian6 committed
1531
                    dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
lijian6's avatar
lijian6 committed
1532
1533
1534
1535

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x         = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
1536
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
1537

lijian6's avatar
lijian6 committed
1538
                    // Copy source meta
lijian6's avatar
lijian6 committed
1539
1540
                    if(lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1541

lijian6's avatar
lijian6 committed
1542
                    // Copy `topk_weights`
lijian6's avatar
lijian6 committed
1543
1544
1545
                    if(lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
                                    ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1546
1547
1548
1549
1550
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
lijian6's avatar
lijian6 committed
1551
            syncwarp();
lijian6's avatar
lijian6 committed
1552
1553
1554
            if(lane_id < kNumRDMARanks and is_lane_ready) {
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1555
1556
        }
    } else {
lijian6's avatar
lijian6 committed
1557
1558
1559
1560
        if(warp_id > kNumForwarders) {
            return;
        }

Chenggang Zhao's avatar
Chenggang Zhao committed
1561
1562
        // Combiners and coordinators
        // RDMA symmetric layout
lijian6's avatar
lijian6 committed
1563
        auto hidden_bytes = hidden_int4 * sizeof(int4);
lijian6's avatar
lijian6 committed
1564
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
lijian6's avatar
lijian6 committed
1565
1566
1567
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1568
1569

        // NVL layouts
lijian6's avatar
lijian6 committed
1570
1571
1572
1573
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
Chenggang Zhao's avatar
Chenggang Zhao committed
1574
            nvl_buffers[i] = buffer_ptrs[i];
lijian6's avatar
lijian6 committed
1575
1576
1577
1578
1579
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
Chenggang Zhao's avatar
Chenggang Zhao committed
1580
1581

        // Combiner warp synchronization
lijian6's avatar
lijian6 committed
1582
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
1583
        __shared__ volatile bool forwarder_retired[kNumForwarders];
lijian6's avatar
lijian6 committed
1584
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
1585
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
lijian6's avatar
lijian6 committed
1586

Chenggang Zhao's avatar
Chenggang Zhao committed
1587
1588
1589
        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA ranks
lijian6's avatar
lijian6 committed
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
            const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
            const auto sub_warp_id   = target_rank % kNumWarpsPerForwarder;
            auto send_buffer         = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            // auto sync_large_warp     = [=]() {
            //     if(kNumWarpsPerForwarder == 1) {
            //         syncwarp();
            //     } else {
            //         // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
            //         // __syncthreads();
            //         syncwarp();
            //     }
            // };
            auto sync_large_warp = [=](const int iter, const int mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1603
                if (kNumWarpsPerForwarder == 1) {
lijian6's avatar
lijian6 committed
1604
                    syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1605
                } else {
lijian6's avatar
lijian6 committed
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
                        // LDS index to store for sync
                        int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        //reset index in the LDS to avoid race condition due to warp scheduling
                        int reset_idx =         dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        auto start_time = wall_clock64();
                        if (lane_id == 0){
                            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
                        }
                        syncwarp();
                        //The while(...) loop polls the counter until all warps have arrived
                        if (lane_id == 0){
                            while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
                                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                                    printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
                                    trap();
                                }
lijian6's avatar
lijian6 committed
1622
1623
                            }
                        }
lijian6's avatar
lijian6 committed
1624
1625
1626
1627
1628
                        syncwarp();
                        if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
                            sync_large_warp_counters[reset_idx] = 0;
                        }
                        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1629
1630
                }
            };
lijian6's avatar
lijian6 committed
1631
            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
1632

lijian6's avatar
lijian6 committed
1633
1634
            // Advance to the corresponding NVL buffer (offsets are applied relative to the original base pointers)
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
lijian6's avatar
lijian6 committed
1635
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
lijian6's avatar
lijian6 committed
1636
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
Chenggang Zhao's avatar
Chenggang Zhao committed
1637
1638
1639
1640
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1641
1642
1643
1644
1645
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
            // sync_forwarder_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1646
1647
1648

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1649
1650
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1651
1652
1653
1654
1655
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
lijian6's avatar
lijian6 committed
1656
            for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1657
                // Check destination queue emptiness, or wait a buffer to be released
lijian6's avatar
lijian6 committed
1658
                auto token_end_idx      = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
Chenggang Zhao's avatar
Chenggang Zhao committed
1659
                auto num_chunked_tokens = token_end_idx - token_start_idx;
lijian6's avatar
lijian6 committed
1660
                auto start_time         = wall_clock64();
lijian6's avatar
lijian6 committed
1661
1662
1663
1664
1665
1666
                while(sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));

                    if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
1667
1668
1669
                        break;

                    // Timeout check
lijian6's avatar
lijian6 committed
1670
1671
1672
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                                channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
1673
1674
1675
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1676
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1677
                sync_large_warp(token_start_idx, 0);
lijian6's avatar
lijian6 committed
1678

Chenggang Zhao's avatar
Chenggang Zhao committed
1679
                // Combine and write to the RDMA buffer
lijian6's avatar
lijian6 committed
1680
                for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1681
                    // Read expected head
lijian6's avatar
lijian6 committed
1682
                    EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1683
                    int expected_head = -1;
lijian6's avatar
lijian6 committed
1684
1685
                    if(lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1686
1687

                    // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1688
                    start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1689
1690
                    while(cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1691
1692

                        // Timeout check
lijian6's avatar
lijian6 committed
1693
1694
1695
                        if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                    channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1696
1697
1698
1699
1700
                            trap();
                        }
                    }

                    // Combine current token
lijian6's avatar
lijian6 committed
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                    expected_head, lane_id,
                                                                                    hidden_int4, num_topk,
                                                                                    reinterpret_cast<int4*>(shifted),
                                                                                    reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                    num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1711
1712

                    // Update head
lijian6's avatar
lijian6 committed
1713
1714
1715
1716
                    if(lane_id < NUM_MAX_NVL_PEERS) {
                        expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
                                        : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1717
                }
lijian6's avatar
lijian6 committed
1718
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1719
                sync_large_warp(token_start_idx, 1);
lijian6's avatar
lijian6 committed
1720

Chenggang Zhao's avatar
Chenggang Zhao committed
1721
                // Issue RDMA send
lijian6's avatar
lijian6 committed
1722
1723
                if(sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1724
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
1725
                        rocshmem::rocshmem_ctx_schar_put_nbi_wave(
lijian6's avatar
lijian6 committed
1726
1727
1728
1729
1730
1731
1732
1733
1734
                            ctx,
                            rdma_channel_data.recv_buffer(rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            rdma_channel_data.send_buffer(dst_rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            num_chunked_tokens * num_bytes_per_rdma_token,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

                        rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1735
1736
1737
1738
1739
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
lijian6's avatar
lijian6 committed
1740
                    syncwarp();
lijian6's avatar
lijian6 committed
1741
                    if(lane_id == 0) {
lijian6's avatar
lijian6 committed
1742
1743
1744
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1745
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1746
1747
1748
1749
                }
            }

            // Retired
lijian6's avatar
lijian6 committed
1750
            syncwarp();
lijian6's avatar
lijian6 committed
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
            if(lane_id == 0) {
                forwarder_retired[target_rank] = true;
            }
        } else if (warp_role == WarpRole::kRDMACoordinator) {
            // Coordinator
            // Sync shared memory status
            // sync_forwarder_smem();
            __syncthreads();
            constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
                // Retired
                if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for(int i = 0; i < kNumRDMARanks; ++i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for(int j = 0; j < num_warps_per_rdma_rank; ++j)
                            if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);

                        if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                        }
                    }
                }

                // Nanosleep and let other warps work
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
            }
        } else if(warp_role == WarpRole::kRDMAReceiver) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1790
1791
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1792
            EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lijian6's avatar
lijian6 committed
1793
1794
1795
1796
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1797
1798
1799

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1800
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1801
1802
1803

            // Iterate over all tokens and combine
            int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1804
            for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1805
                // Read expected head
lijian6's avatar
lijian6 committed
1806
                EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1807
                int expected_head = -1;
lijian6's avatar
lijian6 committed
1808
1809
1810
1811
                if(lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
                                        : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1812
1813
1814
                }

                // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1815
                auto start_time = wall_clock64();
Chenggang Zhao's avatar
Chenggang Zhao committed
1816
                while (cached_channel_tail_idx <= expected_head) {
lijian6's avatar
lijian6 committed
1817
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
Chenggang Zhao's avatar
Chenggang Zhao committed
1818
1819

                    // Timeout check
lijian6's avatar
lijian6 committed
1820
1821
1822
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                                channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1823
1824
1825
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1826
                syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1827
1828

                // Combine current token
lijian6's avatar
lijian6 committed
1829
1830
1831
1832
1833
1834
1835
1836
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                            expected_head, lane_id,
                                                                            hidden_int4, num_topk,
                                                                            combined_x + token_idx * hidden_int4,
                                                                            combined_topk_weights + token_idx * num_topk,
                                                                            num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1837
1838
1839
            }

            // Retired
lijian6's avatar
lijian6 committed
1840
            syncwarp();
lijian6's avatar
lijian6 committed
1841
1842
1843
1844
            if(lane_id == 0) {
                rdma_receiver_retired[target_rank] = true;
            }
        } else if(warp_role == WarpRole::kNVLCoordinator) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1845
1846
            // Coordinator
            // Sync shared memory status
lijian6's avatar
lijian6 committed
1847
1848
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1849
1850
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1851
            int last_rdma_head               = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
1852
            int last_nvl_head[kNumRDMARanks] = {0};
lijian6's avatar
lijian6 committed
1853
1854
            int dst_rdma_rank                = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
lijian6's avatar
lijian6 committed
1855
1856
1857
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1858
                // Retired
lijian6's avatar
lijian6 committed
1859
                if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
Chenggang Zhao's avatar
Chenggang Zhao committed
1860
1861
1862
                    break;

                // Find minimum head for RDMA ranks
lijian6's avatar
lijian6 committed
1863
                {
Chenggang Zhao's avatar
Chenggang Zhao committed
1864
                    int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1865
1866
1867
    #pragma unroll
                    for(int i = 0; i < kNumRDMAReceivers; ++i)
                        if(not rdma_receiver_retired[i])
lijian6's avatar
lijian6 committed
1868
                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
lijian6's avatar
lijian6 committed
1869
1870

                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lijian6's avatar
lijian6 committed
1871
1872
1873
1874
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

1875
1876
                        last_rdma_head = min_head;
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1877
1878
1879
                }

                // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1880
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1881
1882
1883
            }
        }
    }
lijian6's avatar
lijian6 committed
1884
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1885
1886
}

lijian6's avatar
lijian6 committed
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1897
    constexpr int kNumCombineForwarderWarps = 8;
lijian6's avatar
lijian6 committed
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916

#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                                        \
    {                                                                                              \
        auto combine_func =                                                                        \
            low_latency_mode                                                                       \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>           \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>;         \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights,       \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights,            \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1),        \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,           \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr,                    \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

lijian6's avatar
lijian6 committed
1917
    int num_rdma_ranks           = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1918
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
lijian6's avatar
lijian6 committed
1919
1920
    int num_forwarder_warps      = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
lijian6's avatar
lijian6 committed
1921
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
1922
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
lijian6's avatar
lijian6 committed
1923
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
lijian6's avatar
lijian6 committed
1924
    EP_HOST_ASSERT(type == HIP_R_16BF);
Chenggang Zhao's avatar
Chenggang Zhao committed
1925

lijian6's avatar
lijian6 committed
1926
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1927
1928
1929
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
1930

Chenggang Zhao's avatar
Chenggang Zhao committed
1931
1932
1933
} // namespace internode

} // namespace deep_ep
lijian6's avatar
lijian6 committed
1934

lijian6's avatar
lijian6 committed
1935
1936
1937
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
lijian6's avatar
lijian6 committed
1938
1939

#endif