internode.cu 105 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
#include "buffer.cuh"
lijian6's avatar
lijian6 committed
3
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
4
5
#include "launch.cuh"
#include "utils.cuh"
lijian6's avatar
lijian6 committed
6
7
8
9
10
11

#ifndef DISABLE_ROCSHMEM

#include <rocshmem/rocshmem.hpp>

// TODO: fix unroll warnings
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
Chenggang Zhao's avatar
Chenggang Zhao committed
17
18
19
20
21

namespace deep_ep {

namespace internode {

lijian6's avatar
lijian6 committed
22
// rocSHMEM team handle, declared here and defined elsewhere. The host-side `notify_dispatch`
// launcher in this file forwards it to the kernel as the `rdma_team` argument.
extern rocshmem::rocshmem_team_t cpu_rdma_team;
Chenggang Zhao's avatar
Chenggang Zhao committed
23
24
25
26
27
28
29
30
31

// Per-token source metadata: the RDMA rank the token came from plus a bitmask with one bit per
// NVL peer (bit `i` is set when `is_token_in_nvl_ranks[i]` was true at construction time).
struct SourceMeta {
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    // The bitmask encoding below is written for exactly 8 NVL peers.
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // Packs `NUM_MAX_NVL_PEERS` booleans into `is_token_in_nvl_rank_bits`.
    // `is_token_in_nvl_ranks` must point to at least `NUM_MAX_NVL_PEERS` readable elements.
    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks) {
        src_rdma_rank             = rdma_rank;
        // Bit 0 is seeded directly; the remaining bits are OR-ed in by the loop.
        is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0];
#pragma unroll
        for (int i = 1; i < NUM_MAX_NVL_PEERS; ++i)
            is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i;
    }

    // Returns whether bit `nvl_rank` of the mask is set.
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return (is_token_in_nvl_rank_bits >> nvl_rank) & 1;
    }
};

// Host-visible accessor for the size, in bytes, of one `SourceMeta` record.
int get_source_meta_bytes() {
    const int source_meta_bytes = static_cast<int>(sizeof(SourceMeta));
    return source_meta_bytes;
}

lijian6's avatar
lijian6 committed
49
50
51
52
53
54
55
56
// Size in bytes of one token slot in the RDMA buffer.
//
// The slot holds, summed together: `hidden_int4` `int4` values, one `SourceMeta`,
// `num_scales` floats, `num_topk_idx` ints and `num_topk_weights` floats. The sum is passed
// through `ALIGN(..., sizeof(int4))` (macro defined elsewhere; presumably rounds up to a multiple
// of `sizeof(int4)` -- confirm against its definition) and narrowed to `int`.
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    // The `sizeof` operands promote every product to `size_t`, so the sum itself is computed in
    // `size_t`; only the final result is cast back to `int`.
    return static_cast<int>(ALIGN(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
                                      num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
                                      num_topk_weights * sizeof(float),
                                  sizeof(int4)));
}

lijian6's avatar
lijian6 committed
59
60
61
// Computes which region of the RDMA buffer must be zeroed before dispatch.
//
// Returns `{offset, count}`, both expressed in `int32_t` elements:
//   - `offset`: bytes-per-token * `num_rdma_recv_buffer_tokens` * `num_rdma_ranks` * 2 * `num_sms`,
//               divided by `sizeof(int)`;
//   - `count` : `(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    // Widen before multiplying: the byte count is a product of five `int` factors and can exceed
    // `INT_MAX` for large buffers, which would be signed overflow if evaluated in `int`.
    const auto num_bytes_per_token = static_cast<int64_t>(
        get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights));

    // Return `int32_t` offset and count to clean
    return {(num_bytes_per_token * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) /
                sizeof(int),
            (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms};
}
lijian6's avatar
lijian6 committed
68
69
70
71
// Computes which region of the NVL buffer must be zeroed before dispatch.
//
// Returns `{offset, count}`, both expressed in `int32_t` elements:
//   - `offset`: `num_nvl_recv_buffer_tokens` * (per-token bytes: hidden + scales + topk idx +
//               topk weights + `SourceMeta`) * `num_nvl_ranks` * `num_sms`, divided by
//               `sizeof(int)`;
//   - `count` : `num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms`.
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    // Return `int32_t` offset and count to clean
    // `SourceMeta` must be a whole number of ints for the byte -> int conversion below.
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
                              "Invalid size of `SourceMeta`");
    return {
        // The parenthesised per-token size is `size_t` (it contains `sizeof`), so the whole
        // product is evaluated in `size_t`.
        (num_nvl_recv_buffer_tokens *
         (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
          num_topk_weights * sizeof(float) + sizeof(SourceMeta)) *
         num_nvl_ranks * num_sms) /
            sizeof(int),
        num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms,
    };
}

template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
86
87
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
88
89
90
91
    return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank;
}

template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
92
93
94
95
96
97
98
99
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm
    // kLowLatencyMode
    //     ? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
    //     : rocshmem::rocshmem_barrier_all();
    rocshmem::rocshmem_barrier_all();
Chenggang Zhao's avatar
Chenggang Zhao committed
100
101
102
103
}

template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
lijian6's avatar
lijian6 committed
104
105
106
107
108
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
                int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
Chenggang Zhao's avatar
Chenggang Zhao committed
109
                const int nvl_clean_offset, const int nvl_num_int_clean,
lijian6's avatar
lijian6 committed
110
111
112
113
114
115
116
117
                int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
                int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
                const rocshmem::rocshmem_team_t rdma_team) {
    auto sm_id     = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize,
         lane_id     = get_lane_id();
    auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
Chenggang Zhao's avatar
Chenggang Zhao committed
118
119

    auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
120
121
    auto num_rdma_experts = num_experts / kNumRDMARanks,
         num_nvl_experts  = num_rdma_experts / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
122
123
124

    if (sm_id == 0) {
        // Communication with others
lijian6's avatar
lijian6 committed
125
        // Global barrier: the first warp do intra-node sync, the second warp do internode sync
Chenggang Zhao's avatar
Chenggang Zhao committed
126
127
        EP_DEVICE_ASSERT(num_warps > 1);
        EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
lijian6's avatar
lijian6 committed
128
129
        if (thread_id == kWarpSize)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
130

lijian6's avatar
lijian6 committed
131
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
132
133
        __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
134
        // Send numbers of tokens per rank/expert to RDMA ranks
lijian6's avatar
lijian6 committed
135
136
137
        auto rdma_buffer_ptr_int        = reinterpret_cast<int *>(rdma_buffer_ptr);
        auto rdma_recv_num_tokens_mixed = SymBuffer<int>(
            rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
Chenggang Zhao's avatar
Chenggang Zhao committed
138
139

        // Clean up for later data dispatch
lijian6's avatar
lijian6 committed
140
141
        EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <=
                                  rdma_clean_offset * sizeof(int));
Chenggang Zhao's avatar
Chenggang Zhao committed
142
143
144
145
146
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;

        // Copy to send buffer
        for (int i = thread_id; i < num_ranks; i += num_threads)
lijian6's avatar
lijian6 committed
147
148
            rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] =
                num_tokens_per_rank[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
149
        for (int i = thread_id; i < num_experts; i += num_threads)
lijian6's avatar
lijian6 committed
150
151
152
            rdma_recv_num_tokens_mixed.send_buffer(
                i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] =
                num_tokens_per_expert[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
153
        if (thread_id < kNumRDMARanks)
lijian6's avatar
lijian6 committed
154
155
156
            rdma_recv_num_tokens_mixed.send_buffer(
                thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] =
                num_tokens_per_rdma_rank[thread_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
157
158
159
160
161
        __syncthreads();

        // Issue send
        // TODO: more light fence or barrier or signaling
        // TODO: overlap EP barrier and NVL cleaning
lijian6's avatar
lijian6 committed
162
163
164
165
166
167
        if (thread_id < kNumRDMARanks) {
            rocshmem::rocshmem_int_put_nbi(
                rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
                rdma_recv_num_tokens_mixed.send_buffer(thread_id),
                NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
                translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
168
        }
alpha-baby's avatar
alpha-baby committed
169
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
170
        if (thread_id == 0)
lijian6's avatar
lijian6 committed
171
172
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);

Chenggang Zhao's avatar
Chenggang Zhao committed
173
174
175
176
177
        __syncthreads();

        // NVL buffers
        auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr;
        auto nvl_recv_buffer = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
178
179
180
181
182
183
184
185
186
187
        auto nvl_reduced_num_tokens_per_expert =
            Buffer<int>(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer);
        auto nvl_send_num_tokens_per_rank =
            AsymBuffer<int>(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_send_num_tokens_per_expert =
            AsymBuffer<int>(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_rank =
            AsymBuffer<int>(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
        auto nvl_recv_num_tokens_per_expert =
            AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
Chenggang Zhao's avatar
Chenggang Zhao committed
188
189

        // Clean up for later data dispatch
lijian6's avatar
lijian6 committed
190
191
192
193
194
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
        EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
                                      nvl_send_num_tokens_per_rank.total_bytes +
                                      nvl_send_num_tokens_per_expert.total_bytes <=
                                  nvl_clean_offset * sizeof(int));
Chenggang Zhao's avatar
Chenggang Zhao committed
195
196
197
198
199
200
201
202
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;

        // Reduce number of tokens per expert into the NVL send buffer
        // TODO: may use NVSHMEM reduction
        EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
        if (thread_id < num_rdma_experts) {
            int sum = 0;
lijian6's avatar
lijian6 committed
203
204
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
205
206
207
208
209
210
211
212
                sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id];
            nvl_reduced_num_tokens_per_expert[thread_id] = sum;
        }
        __syncthreads();

        // Reduce RDMA received tokens
        if (thread_id == 0) {
            int sum = 0;
lijian6's avatar
lijian6 committed
213
214
215
216
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i) {
                sum +=
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
Chenggang Zhao's avatar
Chenggang Zhao committed
217
218
                recv_rdma_rank_prefix_sum[i] = sum;
            }
lijian6's avatar
lijian6 committed
219
220
            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
221
222
223
224
225
226
            *moe_recv_rdma_counter_mapped = sum;
        }

        // Send numbers of tokens per rank/expert to NVL ranks
        EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
        if (thread_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
227
228
229
230
231
232
233
#pragma unroll
            for (int i = 0; i < kNumRDMARanks; ++i)
                nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] =
                    rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id];
            for (int i = 0; i < num_nvl_experts; ++i)
                nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] =
                    nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
Chenggang Zhao's avatar
Chenggang Zhao committed
234
        }
lijian6's avatar
lijian6 committed
235
236
        memory_fence();
        __syncthreads();
237
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
lijian6's avatar
lijian6 committed
238
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
239

lijian6's avatar
lijian6 committed
240
        // Reduce number of tokens per rank/expert
Chenggang Zhao's avatar
Chenggang Zhao committed
241
242
243
        EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
        if (thread_id == 0) {
            int sum = 0;
lijian6's avatar
lijian6 committed
244
            for (int i = 0; i < num_ranks; ++i) {
Chenggang Zhao's avatar
Chenggang Zhao committed
245
246
247
248
                int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS;
                sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
                recv_gbl_rank_prefix_sum[i] = sum;
            }
lijian6's avatar
lijian6 committed
249
250
            while (ld_volatile_global(moe_recv_counter_mapped) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
251
252
253
254
            *moe_recv_counter_mapped = sum;
        }
        if (thread_id < num_nvl_experts) {
            int sum = 0;
lijian6's avatar
lijian6 committed
255
256
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
257
258
                sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
lijian6's avatar
lijian6 committed
259
260
            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
                ;
Chenggang Zhao's avatar
Chenggang Zhao committed
261
262
263
264
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }

        // Finally barrier
lijian6's avatar
lijian6 committed
265
266
267
268
        __syncthreads();
        if (thread_id == kWarpSize)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);

269
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
270
271
272
273
274
    } else {
        // Calculate meta data
        int dst_rdma_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
275
276
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
277
278
279

            // Iterate over tokens
            int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0};
lijian6's avatar
lijian6 committed
280
281
282
283
284
285
286
287
288
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize) {
                EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
                                          "Invalid number of NVL peers");
                auto is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t *>(
                    is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS);
                auto is_token_in_rank_values =
                    reinterpret_cast<const bool *>(&is_token_in_rank_uint64);
#pragma unroll
                for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j)
Chenggang Zhao's avatar
Chenggang Zhao committed
289
290
291
292
293
294
                    per_nvl_rank_count[j] += is_token_in_rank_values[j];
                total_count += (is_token_in_rank_uint64 != 0);
            }

            // Warp reduce
            total_count = warp_reduce_sum(total_count);
lijian6's avatar
lijian6 committed
295
296
#pragma unroll
            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
297
298
299
                per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]);

            // Write into channel matrix
lijian6's avatar
lijian6 committed
300
301
302
303
304
305
            if (lane_id == 0) {
#pragma unroll
                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                    gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) *
                                                  num_channels +
                                              channel_id] = per_nvl_rank_count[i];
Chenggang Zhao's avatar
Chenggang Zhao committed
306
307
308
309
310
311
                rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count;
            }
        }

        // Calculate prefix sum
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
312
        if (thread_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
313
            auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels;
lijian6's avatar
lijian6 committed
314
            for (int i = 1; i < num_channels; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
315
316
317
                prefix_row[i] += prefix_row[i - 1];
        }

lijian6's avatar
lijian6 committed
318
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
319
        if (thread_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
320
321
322
            auto prefix_row = gbl_channel_prefix_matrix +
                              (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels;
            for (int i = 1; i < num_channels; ++i)
Chenggang Zhao's avatar
Chenggang Zhao committed
323
324
325
326
327
                prefix_row[i] += prefix_row[i - 1];
        }
    }
}

lijian6's avatar
lijian6 committed
328
329
330
331
332
333
334
335
336
337
338
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
339
                     bool low_latency_mode) {
lijian6's avatar
lijian6 committed
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                \
    {                                                                                              \
        auto notify_dispatch_func = low_latency_mode ? notify_dispatch<true, num_rdma_ranks>       \
                                                     : notify_dispatch<false, num_rdma_ranks>;     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks,   \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert,         \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens,             \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second,         \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix,               \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,        \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team);               \
    }                                                                                              \
    break

    constexpr int kNumThreads    = 256;
    const auto    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
357
358

    // Get clean meta
lijian6's avatar
lijian6 committed
359
360
361
362
363
364
365
366
367
368
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());

    // Launch kernel
    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Caps the number of RDMA ranks a token is sent to: the input is returned unchanged while it is
// below 8, and 8 otherwise.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    if (num_rdma_ranks >= 8)
        return 8;
    return num_rdma_ranks;
}

lijian6's avatar
lijian6 committed
383
384
385
386

template <bool kLowLatencyMode,
          int kNumRDMARanks,
          bool kCachedMode,
lijian6's avatar
lijian6 committed
387
388
          int kNumDispatchRDMASenderWarps,
          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
lijian6's avatar
lijian6 committed
389
390
391
392
393
394
395
396
397
398
399
400
401
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
        SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
        const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
        int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
        int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
        const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
        const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
        int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
        int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
        int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
        int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
        int num_ranks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
402
    enum class WarpRole {
lijian6's avatar
lijian6 committed
403
404
405
406
407
        kRDMASender,            // 从x写入到RDMA发送缓存
        kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
        kRDMAAndNVLForwarder,   // 从RDMA接收缓存转写到ipc nvl缓存
        kForwarderCoordinator,  // 向远端RDMA确认接收
        kNVLReceivers           // 从nvl缓存写入到recv_x
Chenggang Zhao's avatar
Chenggang Zhao committed
408
409
    };

lijian6's avatar
lijian6 committed
410
411
412
413
414
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
lijian6's avatar
lijian6 committed
415
416
417
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
418
419
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

lijian6's avatar
lijian6 committed
420
421
422
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);

Chenggang Zhao's avatar
Chenggang Zhao committed
423
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
lijian6's avatar
lijian6 committed
424
425
426
427
428
429
430
431
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASender, -1};
            } else if(warp_id == kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASenderCoordinator, -1};
            }
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            if(warp_id < NUM_MAX_NVL_PEERS) {
Chenggang Zhao's avatar
Chenggang Zhao committed
432
433
434
435
436
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else {
lijian6's avatar
lijian6 committed
437
            return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
Chenggang Zhao's avatar
Chenggang Zhao committed
438
439
        }
    }();
lijian6's avatar
lijian6 committed
440

lijian6's avatar
lijian6 committed
441
    auto warp_role = role_meta.first;
Chenggang Zhao's avatar
Chenggang Zhao committed
442
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
lijian6's avatar
lijian6 committed
443
444
445
    // if(lane_id==0){
    //     printf("tid=%d, bid=%d, warp_role=%d\n", threadIdx.x, blockIdx.x, warp_role);
    // }
Chenggang Zhao's avatar
Chenggang Zhao committed
446
447

    // RDMA symmetric layout
lijian6's avatar
lijian6 committed
448
449
450
451
452
453
    auto hidden_bytes             = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
454
455

    // NVL buffer layouts
lijian6's avatar
lijian6 committed
456
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
Chenggang Zhao's avatar
Chenggang Zhao committed
457
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
lijian6's avatar
lijian6 committed
458
    int rs_wr_rank = 0, ws_rr_rank = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
459
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
lijian6's avatar
lijian6 committed
460
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
461
    if (warp_role == WarpRole::kNVLReceivers)
lijian6's avatar
lijian6 committed
462
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
463
464

    // Allocate buffers
lijian6's avatar
lijian6 committed
465
466
467
468
469
470
471
472
473
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
474
475

    // RDMA sender warp synchronization
lijian6's avatar
lijian6 committed
476
477
478
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
479

lijian6's avatar
lijian6 committed
480
481
    // NVL and RDMA coordinate Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
482
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
lijian6's avatar
lijian6 committed
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503

    // Place the main logic of your kernel here, using the parameters above.
    if(warp_role == WarpRole::kRDMASender) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
        然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
        最后,它复制相关的数据到对称发送缓冲区。

        kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
        期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
        同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
        */
        // 获取任务范围
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // 清理共享内存
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
        if(warp_id == 0 && lane_id == 0) {
            rdma_send_next_token_idx = token_start_idx;
lijian6's avatar
lijian6 committed
504
        }
lijian6's avatar
lijian6 committed
505
506
507
        if(warp_id == 0 && lane_id < kNumRDMARanks) {
            rdma_send_channel_tail[lane_id]      = 0;
            rdma_send_channel_next_tail[lane_id] = 0;
lijian6's avatar
lijian6 committed
508
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
509

lijian6's avatar
lijian6 committed
510
511
512
513
514
515
        // 发送本通道中的令牌数量,通过 `-value - 1` 表示
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
        // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
        // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
        for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
516
            if (lane_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
517
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
518
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
519
                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
520
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
521
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
522
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
lijian6's avatar
lijian6 committed
523
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
524
            }
lijian6's avatar
lijian6 committed
525

lijian6's avatar
lijian6 committed
526
527
528
            syncwarp();
            if (dst_rdma_rank != rdma_rank) {
                rocshmem::rocshmem_ctx_int_put_nbi_wave(
lijian6's avatar
lijian6 committed
529
530
531
                ctx, rdma_channel_meta.recv_buffer(rdma_rank),
                rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
532
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
533
        }
lijian6's avatar
lijian6 committed
534
        rocshmem::rocshmem_ctx_quiet(ctx);
lijian6's avatar
lijian6 committed
535
536
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
537

lijian6's avatar
lijian6 committed
538
        // 遍历令牌并复制到缓冲区
Chenggang Zhao's avatar
Chenggang Zhao committed
539
        int64_t token_idx;
lijian6's avatar
lijian6 committed
540
541
542
543
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // 读取RDMA秩的存在性
Chenggang Zhao's avatar
Chenggang Zhao committed
544
            uint64_t is_token_in_rank_uint64 = 0;
lijian6's avatar
lijian6 committed
545
546
547
            if(lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
548

lijian6's avatar
lijian6 committed
549
550
551
552
            // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
            while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
                // 等待
            }
lijian6's avatar
lijian6 committed
553
            syncwarp();
554

lijian6's avatar
lijian6 committed
555
            // 获取下一个尾部位置
lijian6's avatar
lijian6 committed
556
            int rdma_tail_idx = -1;
lijian6's avatar
lijian6 committed
557
            if(is_token_in_rank_uint64 != 0) {
lijian6's avatar
lijian6 committed
558
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
lijian6's avatar
lijian6 committed
559
560
561
562
563

                // 与kForwarderCoordinator相互配合,调节发送数据的频率
                while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
564
            }
lijian6's avatar
lijian6 committed
565
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
566

lijian6's avatar
lijian6 committed
567
568
            // 存储RDMA头部以供合并
            if(lane_id < kNumRDMARanks && !kCachedMode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
569
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
lijian6's avatar
lijian6 committed
570
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
571

lijian6's avatar
lijian6 committed
572
573
574
575
            // 更新最后一个令牌尾部
            if(last_rdma_tail_idx >= 0) {
                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            }
lijian6's avatar
lijian6 committed
576
577
            last_rdma_tail_idx = rdma_tail_idx;

lijian6's avatar
lijian6 committed
578
579
580
581
            // 释放顺序锁
            if(lane_id == 0) {
                rdma_send_next_token_idx += 1;
            }
lijian6's avatar
lijian6 committed
582

lijian6's avatar
lijian6 committed
583
            // 广播尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
584
            SourceMeta src_meta;
lijian6's avatar
lijian6 committed
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            /*
            该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
            */
            #pragma unroll
            for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
                // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
                if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
                    // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
                    //  计算slot_idx在接收缓冲区中的位置
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;

                    // 存储当前RDMA秩到topk_ranks数组中
                    topk_ranks[num_topk_ranks] = i;

                    // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
lijian6's avatar
lijian6 committed
602
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
lijian6's avatar
lijian6 committed
603
604
605
606
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);

                    // 如果当前lane_id等于num_topk_ranks,则更新src_meta
                    if(lane_id == num_topk_ranks) {
lijian6's avatar
lijian6 committed
607
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
lijian6's avatar
lijian6 committed
608
609
610
611
612
                    }

                    // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
                    // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
lijian6's avatar
lijian6 committed
613
                }
lijian6's avatar
lijian6 committed
614
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
615
616
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

lijian6's avatar
lijian6 committed
617
618
619
620
621
622
            // 复制 `x` 到对称发送缓冲区
            auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
623
            };
lijian6's avatar
lijian6 committed
624
625
626
627
628
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
            }
lijian6's avatar
lijian6 committed
629

lijian6's avatar
lijian6 committed
630
631
632
633
634
635
636
637
            // 复制源元数据到对称发送缓冲区
            if(lane_id < num_topk_ranks) {
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
            }
lijian6's avatar
lijian6 committed
638

lijian6's avatar
lijian6 committed
639
640
641
            // 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_scales; i += kWarpSize) {
lijian6's avatar
lijian6 committed
642
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
lijian6's avatar
lijian6 committed
643
644
645
646
647
648
649
650
651
652
653

                // auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
                // auto value = ld_nc_global(x_scales + offset);
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
                }
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
lijian6's avatar
lijian6 committed
654
            }
655

lijian6's avatar
lijian6 committed
656
657
658
            // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
Chenggang Zhao's avatar
Chenggang Zhao committed
659
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
lijian6's avatar
lijian6 committed
660
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
661
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
lijian6's avatar
lijian6 committed
662
663
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
664
            }
lijian6's avatar
lijian6 committed
665
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
666

lijian6's avatar
lijian6 committed
667
668
669
670
671
672
        // 结尾部分
        // 获取顺序锁
        while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
            // 等待
        }

lijian6's avatar
lijian6 committed
673
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
674

lijian6's avatar
lijian6 committed
675
676
677
678
        // 更新最后一个令牌尾部
        if(last_rdma_tail_idx >= 0) {
            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
679

lijian6's avatar
lijian6 committed
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
        // 释放顺序锁
        if(lane_id == 0) {
            rdma_send_next_token_idx += 1;
        }
    } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
        如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
        最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。

        kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
        nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
        */
        if(warp_id > kNumDispatchRDMASenderWarps) {
            return;
        }
        // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
699

lijian6's avatar
lijian6 committed
700
701
702
        // 同步共享内存,确保所有线程在继续之前都达到了这一点
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
703

lijian6's avatar
lijian6 committed
704
        // 计算当前通道需要发送的令牌数
Chenggang Zhao's avatar
Chenggang Zhao committed
705
        int num_tokens_to_send = 0;
lijian6's avatar
lijian6 committed
706
        if(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
707
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
lijian6's avatar
lijian6 committed
708
709
            if(channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
710
711
        }

lijian6's avatar
lijian6 committed
712
        // 记录上次发出的尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
713
        int last_issued_tail = 0;
lijian6's avatar
lijian6 committed
714
715
716
717
718
719
720
        // 当有任何RDMA秩需要发送令牌时,继续循环
        while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
            for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // 计算目标RDMA秩
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;

                // 获取同步后的需要发送的令牌数
lijian6's avatar
lijian6 committed
721
                synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
722

lijian6's avatar
lijian6 committed
723
724
725
726
                if(synced_num_tokens_to_send == 0)
                    continue; // 如果没有令牌需要发送,则跳过

                // 读取进度
lijian6's avatar
lijian6 committed
727
                auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
lijian6's avatar
lijian6 committed
728
729
730
731
732
                auto processed_tail          = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed    = processed_tail - synced_last_issued_tail;

                // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
                if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
733
734
                    continue;

lijian6's avatar
lijian6 committed
735
736
737
738
739
740
                // 计算本次需要发出的令牌数
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);

                // 发出RDMA发送请求
                if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
741
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
742
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
743
744
745
746
747
748
749
750
751
                    rocshmem::rocshmem_ctx_schar_put_nbi_wave(
                        ctx,
                        rdma_channel_data.recv_buffer(rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        rdma_channel_data.send_buffer(dst_rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        num_bytes_per_rdma_token * num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                    rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
752
                } else {
lijian6's avatar
lijian6 committed
753
                    // 对于本地RDMA秩,使用较轻的内存屏障
Chenggang Zhao's avatar
Chenggang Zhao committed
754
755
756
                    memory_fence();
                }

lijian6's avatar
lijian6 committed
757
                // 更新尾部位置
lijian6's avatar
lijian6 committed
758
                syncwarp();
lijian6's avatar
lijian6 committed
759
                if(lane_id == dst_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
760
761
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
lijian6's avatar
lijian6 committed
762
                    // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
lijian6's avatar
lijian6 committed
763
764
765
                    rocshmem::rocshmem_ctx_ulong_atomic_add(
                        ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
766
767
                }
            }
lijian6's avatar
lijian6 committed
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
        } // while(__any(num_tokens_to_send > 0))
    } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
        它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
        接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
        然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
        最后,它同步头部和尾部索引,并标记通道为退役状态。
        */
        // RDMA消费者和NVL生产者
        const auto dst_nvl_rank          = target_rank;                                       // 目标NVL秩
        const auto dst_rank              = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;      // 目标秩
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);              // 目标秩专家开始
        const auto dst_rank_expert_end   = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束

        // 等待计数器到达
Chenggang Zhao's avatar
Chenggang Zhao committed
784
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
lijian6's avatar
lijian6 committed
785
786
        EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
787
788
789
790
791
792
793
794
795
        if(lane_id < kNumRDMARanks) {
            while(true) {
                // 对应于kRDMASender中的数据写入
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);                     // 是nvl节点的起始地址
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);            // 本rdma节点的起始地址
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);        // 本节点的结束地址
                if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
                    // 通知NVL秩
Chenggang Zhao's avatar
Chenggang Zhao committed
796
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
lijian6's avatar
lijian6 committed
797
798
799
                    EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);

                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
800
801
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

lijian6's avatar
lijian6 committed
802
803
                    // 保存从RDMA通道接收的令牌计数
                    src_rdma_channel_prefix = -meta_2 - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
804
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
lijian6's avatar
lijian6 committed
805
806
807
808
809
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
                    if(!kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;

                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
Chenggang Zhao's avatar
Chenggang Zhao committed
810
811
812
813
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

lijian6's avatar
lijian6 committed
814
815
816
817
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
Chenggang Zhao's avatar
Chenggang Zhao committed
818
819
820
821
                    trap();
                }
            }
        }
lijian6's avatar
lijian6 committed
822
        syncwarp();
lijian6's avatar
lijian6 committed
823
824

        // 移动缓存的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
825
826
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

lijian6's avatar
lijian6 committed
827
828
829
        // 等待共享内存被清理
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
830

lijian6's avatar
lijian6 committed
831
832
833
834
        // 开始准备处理接受数据,直到所有的数据接受完成。
        // 转发从RDMA缓冲区的令牌
        // 注意:总是从本地秩开始
        int src_rdma_rank = sm_id % kNumRDMARanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
835
836
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
lijian6's avatar
lijian6 committed
837
838
        while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
            // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
lijian6's avatar
lijian6 committed
839
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
840
841
842

            // 用于给kNVLReceivers进行互动,控制数据的传输速度
            while(lane_id == 0) {
lijian6's avatar
lijian6 committed
843
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
lijian6's avatar
lijian6 committed
844
                if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
845
                    break;
lijian6's avatar
lijian6 committed
846
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
847

lijian6's avatar
lijian6 committed
848
849
850
851
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
852
853
854
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
855
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
856

lijian6's avatar
lijian6 committed
857
            // 找到下一个源RDMA秩(轮询)
lijian6's avatar
lijian6 committed
858
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
859
            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
860
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
lijian6's avatar
lijian6 committed
861
862
863
864
865
                if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));

                    if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
Chenggang Zhao's avatar
Chenggang Zhao committed
866
                        break;
lijian6's avatar
lijian6 committed
867
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
868
869
                }

lijian6's avatar
lijian6 committed
870
871
872
873
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
Chenggang Zhao's avatar
Chenggang Zhao committed
874
875
876
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
877

lijian6's avatar
lijian6 committed
878
879
            auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
880

lijian6's avatar
lijian6 committed
881
882
883
884
885
886
887
888
889
890
            // 遍历RDMA缓冲区中的每一个令牌
            for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
                void* shifted           = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
                auto src_meta           = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
                if(lane_id == src_rdma_rank) {
                    num_tokens_to_recv_from_rdma -= 1;
                }

Chenggang Zhao's avatar
Chenggang Zhao committed
891
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
lijian6's avatar
lijian6 committed
892
                if(lane_id == src_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
893
894
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
lijian6's avatar
lijian6 committed
895
                    if(!kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
896
897
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
lijian6's avatar
lijian6 committed
898
899

                if(!is_in_dst_nvl_rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
900
901
                    continue;

lijian6's avatar
lijian6 committed
902
                // 获取一个空闲槽位
lijian6's avatar
lijian6 committed
903
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
904

lijian6's avatar
lijian6 committed
905
                // 复制数据
lijian6's avatar
lijian6 committed
906
907
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
lijian6's avatar
lijian6 committed
908
909
910
                                   reinterpret_cast<int4*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
lijian6's avatar
lijian6 committed
911

lijian6's avatar
lijian6 committed
912
913
                // 复制源元数据
                if(lane_id == 0)
lijian6's avatar
lijian6 committed
914
                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
lijian6's avatar
lijian6 committed
915
                shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
lijian6's avatar
lijian6 committed
916

lijian6's avatar
lijian6 committed
917
                // 复制 `x_scales`
lijian6's avatar
lijian6 committed
918
919
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
lijian6's avatar
lijian6 committed
920
921
922
923
924
925
926
927
928
929
930
931
932
933
                                   reinterpret_cast<float*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<float*>(shifted) + num_scales;

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
                    // 读取
                    auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
                    shifted = reinterpret_cast<int*>(shifted) + num_topk;
                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);

                    // 转换和写入
                    idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
lijian6's avatar
lijian6 committed
934
                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
lijian6's avatar
lijian6 committed
935
                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
936
937
                }

lijian6's avatar
lijian6 committed
938
939
                // 在NVL缓冲区不足的情况下,提前停止
                if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
940
941
942
                    src_rdma_tail = i + 1;
            }

lijian6's avatar
lijian6 committed
943
944
945
            // 同步头部索引
            if(lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
946

lijian6's avatar
lijian6 committed
947
            // 移动尾部索引,与kNVLReceivers互相通信使用
lijian6's avatar
lijian6 committed
948
            syncwarp();
lijian6's avatar
lijian6 committed
949
950
951
            if(lane_id == 0) {
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
952
953
954
        }

        // Retired
lijian6's avatar
lijian6 committed
955
        syncwarp();
lijian6's avatar
lijian6 committed
956
        if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
957
            forward_channel_retired[dst_nvl_rank] = true;
lijian6's avatar
lijian6 committed
958
959
960
961
962
963
964
965
966
        }
    } else if(warp_role == WarpRole::kForwarderCoordinator) {
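        /*
        Forwarder coordinator. Extra coordinator warps exit immediately. The remaining warp
        zeroes the `forward_channel_head` table and clears the `forward_channel_retired`
        flags, then loops: each lane takes the minimum head over the NVL forwarders that
        have not retired yet, for its target RDMA rank. The loop ends once every lane sees
        all forwarders retired. Otherwise, when the minimum head has advanced by at least
        `num_max_rdma_chunked_send_tokens`, the lane adds the advance to the remote RDMA
        head counter, and the warp then sleeps briefly so that other warps can make progress.
        */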
Chenggang Zhao's avatar
Chenggang Zhao committed
967
        // Extra warps for forwarder coordinator should exit directly
lijian6's avatar
lijian6 committed
968
        if (warp_id > NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
969
970
            return;

lijian6's avatar
lijian6 committed
971
972
973
974
975
976
        // 转发warp协调器
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        // 清理共享内存
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
        for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
Chenggang Zhao's avatar
Chenggang Zhao committed
977
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
lijian6's avatar
lijian6 committed
978
        if(lane_id < NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
979
            forward_channel_retired[lane_id] = false;
lijian6's avatar
lijian6 committed
980
981
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
982
983

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
lijian6's avatar
lijian6 committed
984
985
986

        while(true) {
            // 找到最小的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
987
            int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
988
#pragma unroll
lijian6's avatar
lijian6 committed
989
990
            for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if(!forward_channel_retired[i])
lijian6's avatar
lijian6 committed
991
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
lijian6's avatar
lijian6 committed
992
993

            if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
Chenggang Zhao's avatar
Chenggang Zhao committed
994
                break;
lijian6's avatar
lijian6 committed
995
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
996

lijian6's avatar
lijian6 committed
997
998
            // 更新远程头部
            if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
lijian6's avatar
lijian6 committed
999
1000
1001
                rocshmem::rocshmem_ctx_ulong_atomic_add(
                    ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
1002
1003
                last_head = min_head;
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1004

lijian6's avatar
lijian6 committed
1005
            // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1006
            __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1007
        }
lijian6's avatar
lijian6 committed
1008
1009
1010
1011
1012
1013
1014
1015
    } else if(warp_role == WarpRole::kNVLReceivers) {
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }

        // NVL consumers
        // Load the receive offset from `recv_gbl_rank_prefix_sum`
        // (each lane's register holds the offset for one source RDMA rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
1016
        int src_nvl_rank = target_rank, total_offset = 0;
lijian6's avatar
lijian6 committed
1017
1018
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1019
1020
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

lijian6's avatar
lijian6 committed
1021
1022
        // 接收通道偏移
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
1023
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1024
1025

        while(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1026
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1027
            end_offset   = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1028
            if(start_offset < 0 && end_offset < 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1029
1030
1031
1032
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
lijian6's avatar
lijian6 committed
1033
1034
1035
1036
            // 超时检查
            if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
1037
1038
1039
                trap();
            }
        }
lijian6's avatar
lijian6 committed
1040

Chenggang Zhao's avatar
Chenggang Zhao committed
1041
1042
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

lijian6's avatar
lijian6 committed
1043
1044
1045
        // 保存以供合并使用
        if(lane_id < kNumRDMARanks && !kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
lijian6's avatar
lijian6 committed
1046
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1047
1048

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1049
1050
        while(num_tokens_to_recv > 0) {
            // 通过通道0检查通道状态
lijian6's avatar
lijian6 committed
1051
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1052
1053
1054
            while(lane_id == 0) {
                // 准备复制
                if(cached_channel_head_idx != cached_channel_tail_idx)
Chenggang Zhao's avatar
Chenggang Zhao committed
1055
                    break;
lijian6's avatar
lijian6 committed
1056
1057
1058
1059
1060
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                            channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1061
1062
1063
1064
                    trap();
                }
            }

lijian6's avatar
lijian6 committed
1065
            // 同步队列尾部
lijian6's avatar
lijian6 committed
1066
1067
            cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);

lijian6's avatar
lijian6 committed
1068
            // 复制数据
Chenggang Zhao's avatar
Chenggang Zhao committed
1069
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1070
1071
1072
1073
            for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto meta               = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx  = shfl_sync(total_offset, meta.src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1074
1075
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

lijian6's avatar
lijian6 committed
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
                // 复制数据
                UNROLLED_WARP_COPY(5,
                                lane_id,
                                hidden_int4,
                                recv_x + recv_token_idx * hidden_int4,
                                nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                ld_nc_global,
                                st_na_global);

                // 复制源元数据
                if(lane_id == 0 && !kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
1087
                    st_na_global(recv_src_meta + recv_token_idx, meta);
lijian6's avatar
lijian6 committed
1088

lijian6's avatar
lijian6 committed
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
                // 复制比例
                UNROLLED_WARP_COPY(1,
                                lane_id,
                                num_scales,
                                recv_x_scales + recv_token_idx * num_scales,
                                nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                ld_nc_global,
                                st_na_global);

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
lijian6's avatar
lijian6 committed
1100
1101
                    auto recv_idx   = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
lijian6's avatar
lijian6 committed
1102
1103
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1104
1105
1106
                }
            }

lijian6's avatar
lijian6 committed
1107
            // 移动队列
lijian6's avatar
lijian6 committed
1108
            syncwarp();
lijian6's avatar
lijian6 committed
1109
            if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1110
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
lijian6's avatar
lijian6 committed
1111
1112
            }
        } // while(num_tokens_to_recv > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1113
    }
lijian6's avatar
lijian6 committed
1114
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1115
1116
}

lijian6's avatar
lijian6 committed
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1130
    constexpr int kNumDispatchRDMASenderWarps = 7;
lijian6's avatar
lijian6 committed
1131
1132
    // Make sure never OOB
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
lijian6's avatar
lijian6 committed
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                       \
    {                                                                                              \
        auto dispatch_func =                                                                       \
            low_latency_mode                                                                       \
                ? (is_cached_dispatch                                                              \
                       ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>         \
                       : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>)       \
                : (is_cached_dispatch                                                              \
                       ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>        \
                       : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>);     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch_func, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx,   \
            recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta),                      \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head,   \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix,        \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,      \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales,       \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,       \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
Chenggang Zhao's avatar
Chenggang Zhao committed
1158
1159
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

lijian6's avatar
lijian6 committed
1160
1161
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
                        (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1162
1163
1164
1165
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lijian6's avatar
lijian6 committed
1166
template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
1167
__global__ void __launch_bounds__(1024, 1)
lijian6's avatar
lijian6 committed
1168
1169
1170
1171
1172
1173
1174
1175
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
              bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) {
    auto sm_id       = static_cast<int>(blockIdx.x);
    auto thread_id   = static_cast<int>(threadIdx.x);
Chenggang Zhao's avatar
Chenggang Zhao committed
1176
    auto num_threads = static_cast<int>(blockDim.x);
lijian6's avatar
lijian6 committed
1177
1178
1179
    auto num_warps   = num_threads / kWarpSize;
    auto warp_id     = thread_id / kWarpSize;
    auto lane_id     = get_lane_id();
Chenggang Zhao's avatar
Chenggang Zhao committed
1180

lijian6's avatar
lijian6 committed
1181
    auto nvl_rank       = rank % NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1182
1183
1184
1185
1186
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Using two SMs, which clean the RDMA/NVL buffer respectively
    if (sm_id == 0) {
        // Barrier for RDMA
lijian6's avatar
lijian6 committed
1187
1188
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Shangyan Zhou's avatar
Fix  
Shangyan Zhou committed
1189

lijian6's avatar
lijian6 committed
1190
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1191

lijian6's avatar
lijian6 committed
1192
1193
        // Clean
        auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
1194
1195
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1196
1197
1198
1199
1200
1201
        rocshmem::rocshmem_fence();
        __syncthreads();

        // Barrier again
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1202

lijian6's avatar
lijian6 committed
1203
1204
1205
1206
1207
1208
1209
    } else if (sm_id == 1) {
        // Barrier for NVL
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
        __syncthreads();

        // Clean
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
Chenggang Zhao's avatar
Chenggang Zhao committed
1210
1211
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1212
        memory_fence();
1213
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1214
1215

        // Barrier again
1216
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
lijian6's avatar
lijian6 committed
1217
    } else if (sm_id == 2) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1218
1219
1220
1221
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1222
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
1223
1224
1225
1226

        // Iterate in reverse order
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1227
1228
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1229
1230
1231

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1232
1233
1234
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                auto current_head =
                    __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1247
1248
1249
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
                                  rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
1250

lijian6's avatar
lijian6 committed
1251
1252
1253
        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
            for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks;
                 dst_rdma_rank += num_channels * 2 - 3) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1254
                // Iterate in reverse order
lijian6's avatar
lijian6 committed
1255
1256
1257
1258
1259
1260
                int token_start_idx =
                    warp_id == 0
                        ? 0
                        : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx =
                    rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
1261
1262
1263
1264
1265
                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1266
1267
1268
1269
1270
1271
1272
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                    auto current_head =
                        __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                    if (current_head < 0) {
                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
1273
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1274
1275
1276
1277
1278
1279
1280
                }
            }
        }
    }
}

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
1281
1282
1283
1284
1285
1286
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
1287
                   bool is_cached_dispatch, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1288
    const int  num_threads    = ::max(128, kWarpSize * num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1289
1290
1291
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
lijian6's avatar
lijian6 committed
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
1302
1303
1304
1305
1306
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_channels * 2 > 3);

    // Launch kernel
lijian6's avatar
lijian6 committed
1307
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
Chenggang Zhao's avatar
Chenggang Zhao committed
1308
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
lijian6's avatar
lijian6 committed
1309
1310
1311
1312
1313
1314
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1315
1316
}

lijian6's avatar
lijian6 committed
1317
1318
1319
1320
1321
// Warp-cooperative reduction of one token over every source rank that holds it.
//
// Template parameters:
//   kNumRanks     number of lanes (one per rank) whose `is_token_in_rank` / `head_idx`
//                 are inspected.
//   dtype_t       element type packed inside each `int4` of hidden data.
//   kMaxNumRanks  capacity of the per-token rank / slot arrays (must be <= kWarpSize).
//   ReceiveFn     callable `(rank, slot_idx, hidden_int4_idx) -> int4`.
//   ReceiveTWFn   callable `(rank, slot_idx, topk_idx) -> float`.
//
// Parameters:
//   is_token_in_rank, head_idx  per-lane inputs; lane `i` supplies the values for rank `i`.
//   lane_id                     caller's lane index within the warp.
//   hidden_int4                 number of `int4` elements in one hidden row.
//   num_topk                    number of top-k weights to reduce (one lane each).
//   combined_row                output row, `hidden_int4` `int4` elements.
//   combined_topk_weights       output weights, `num_topk` floats.
//   num_max_recv_tokens         ring size; a head is mapped to a slot by `head % size`.
//   recv_fn, recv_tw_fn         data / weight loaders described above.
//
// Returns the first (lowest-index) rank that holds the token.
// NOTE: `topk_ranks` is not initialised, so the return value is indeterminate when no
// lane reports the token; callers must not rely on it in that case.
//
// The `shfl_sync` calls are evaluated by every lane on every iteration (the branch
// condition is itself the shuffle result), so the whole warp must call this function
// together.
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        // Ranks are appended in ascending order, so `topk_ranks[0]` is the smallest one
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }
    EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks);

    // Reduce data
    // Each lane handles the `int4` elements `lane_id, lane_id + kWarpSize, ...` of the row
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

        // Reduce all-to-all results
        // Accumulate in `float` regardless of `dtype_t`
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`
    // Lane `t` (for `t < num_topk`) sums weight `t` over all contributing ranks
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank
    return topk_ranks[0];
}

lijian6's avatar
lijian6 committed
1377
1378
1379
1380
template <bool kLowLatencyMode,
          int kNumRDMARanks,
          typename dtype_t,
          int kNumCombineForwarderWarps,
lijian6's avatar
lijian6 committed
1381
          int kNumTopkRDMARanks     = get_num_topk_rdma_ranks(kNumRDMARanks),
lijian6's avatar
lijian6 committed
1382
          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
lijian6's avatar
lijian6 committed
1383
          int kNumForwarders        = kNumRDMARanks * kNumWarpsPerForwarder,
lijian6's avatar
lijian6 committed
1384
1385
1386
          int kNumRDMAReceivers     = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) 
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
lijian6's avatar
lijian6 committed
1387
1388
1389
1390
1391
1392
1393
1394
            const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
            const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
            const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
            const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
            int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
            int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
            int num_ranks) {
lijian6's avatar
lijian6 committed
1395
1396
1397
1398
1399
1400
1401
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kRDMACoordinator,
        kNVLCoordinator
    };
Chenggang Zhao's avatar
Chenggang Zhao committed
1402

lijian6's avatar
lijian6 committed
1403
1404
1405
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

lijian6's avatar
lijian6 committed
1406
1407
1408
1409
1410
1411
    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

Chenggang Zhao's avatar
Chenggang Zhao committed
1412
1413
1414
1415
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
1416
1417
1418
1419
1420
1421
1422

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumForwarders) {
                return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
Chenggang Zhao's avatar
Chenggang Zhao committed
1423
            } else {
lijian6's avatar
lijian6 committed
1424
                return {WarpRole::kRDMACoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1425
1426
            }
        } else {
lijian6's avatar
lijian6 committed
1427
            if(warp_id < kNumForwarders) {
lijian6's avatar
lijian6 committed
1428
                return {WarpRole::kRDMAReceiver, warp_id};
Chenggang Zhao's avatar
Chenggang Zhao committed
1429
            } else {
lijian6's avatar
lijian6 committed
1430
                return {WarpRole::kNVLCoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1431
1432
1433
            }
        }
    }();
lijian6's avatar
lijian6 committed
1434

Chenggang Zhao's avatar
Chenggang Zhao committed
1435
    auto warp_role = role_meta.first;
lijian6's avatar
lijian6 committed
1436
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
Chenggang Zhao's avatar
Chenggang Zhao committed
1437

lijian6's avatar
lijian6 committed
1438
    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
1439
1440
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1441
    // This approach is designed to sync multiple warps in a loop
lijian6's avatar
lijian6 committed
1442
1443
1444
1445
    constexpr int num_sync_large_iteration = 64;
    constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
    __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];   
    for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
lijian6's avatar
lijian6 committed
1446
1447
1448
1449
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
1450
    if (warp_role == WarpRole::kNVLSender) {
lijian6's avatar
lijian6 committed
1451
1452
1453
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
1454

lijian6's avatar
lijian6 committed
1455
        const auto dst_nvl_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
1456
1457
1458
        // NVL layouts
        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
1459
1460
1461
1462
1463
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
1464

Chenggang Zhao's avatar
Chenggang Zhao committed
1465
1466
        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
lijian6's avatar
lijian6 committed
1467
1468
        if(lane_id < kNumRDMARanks) {
            int prefix_idx  = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
Chenggang Zhao's avatar
Chenggang Zhao committed
1469
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
lijian6's avatar
lijian6 committed
1470
            token_end_idx   = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1471
        }
lijian6's avatar
lijian6 committed
1472
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1473
1474
1475

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1476
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1477
1478

        // Iterate over all tokens and send by chunks
lijian6's avatar
lijian6 committed
1479
        while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1480
            // Exit if possible
lijian6's avatar
lijian6 committed
1481
            if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1482
1483
                break;

lijian6's avatar
lijian6 committed
1484
            // Decide next RDMA buffer to send
Chenggang Zhao's avatar
Chenggang Zhao committed
1485
            bool is_lane_ready = false;
lijian6's avatar
lijian6 committed
1486
            auto start_time    = wall_clock64();
lijian6's avatar
lijian6 committed
1487
1488

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1489
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1490
                is_lane_ready      = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
lijian6's avatar
lijian6 committed
1491
1492
1493
                                num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;

                if(__any_sync(kFullWarpMask, is_lane_ready))
Chenggang Zhao's avatar
Chenggang Zhao committed
1494
                    break;
lijian6's avatar
lijian6 committed
1495

Chenggang Zhao's avatar
Chenggang Zhao committed
1496
                // Retry
lijian6's avatar
lijian6 committed
1497
1498
                if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1499
1500

                // Timeout check
lijian6's avatar
lijian6 committed
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
                if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
                        "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        ld_volatile_global(nvl_channel_head.buffer() + lane_id),
                        cached_channel_tail_idx,
                        token_start_idx,
                        token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1513
1514
1515
1516
1517
                    trap();
                }
            }

            // Sync token start index and count
lijian6's avatar
lijian6 committed
1518
1519
            for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
                if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1520
1521
1522
                    continue;

                // Sync token start index
lijian6's avatar
lijian6 committed
1523
1524
                auto token_idx          = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1525
1526

                // Send by chunk
lijian6's avatar
lijian6 committed
1527
                for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1528
1529
                    // Get an empty slot
                    int dst_slot_idx = 0;
lijian6's avatar
lijian6 committed
1530
1531
1532
                    if(lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
1533
                    }
lijian6's avatar
lijian6 committed
1534
                    dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
lijian6's avatar
lijian6 committed
1535
1536
1537
1538

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x         = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
1539
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
1540

lijian6's avatar
lijian6 committed
1541
                    // Copy source meta
lijian6's avatar
lijian6 committed
1542
1543
                    if(lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1544

lijian6's avatar
lijian6 committed
1545
                    // Copy `topk_weights`
lijian6's avatar
lijian6 committed
1546
1547
1548
                    if(lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
                                    ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1549
1550
1551
1552
1553
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
lijian6's avatar
lijian6 committed
1554
            syncwarp();
lijian6's avatar
lijian6 committed
1555
1556
1557
            if(lane_id < kNumRDMARanks and is_lane_ready) {
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1558
1559
        }
    } else {
lijian6's avatar
lijian6 committed
1560
1561
1562
1563
        if(warp_id > kNumForwarders) {
            return;
        }

Chenggang Zhao's avatar
Chenggang Zhao committed
1564
1565
        // Combiners and coordinators
        // RDMA symmetric layout
lijian6's avatar
lijian6 committed
1566
        auto hidden_bytes = hidden_int4 * sizeof(int4);
lijian6's avatar
lijian6 committed
1567
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
lijian6's avatar
lijian6 committed
1568
1569
1570
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1571
1572

        // NVL layouts
lijian6's avatar
lijian6 committed
1573
1574
1575
1576
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
Chenggang Zhao's avatar
Chenggang Zhao committed
1577
            nvl_buffers[i] = buffer_ptrs[i];
lijian6's avatar
lijian6 committed
1578
1579
1580
1581
1582
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
Chenggang Zhao's avatar
Chenggang Zhao committed
1583
1584

        // Combiner warp synchronization
lijian6's avatar
lijian6 committed
1585
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
1586
        __shared__ volatile bool forwarder_retired[kNumForwarders];
lijian6's avatar
lijian6 committed
1587
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
1588
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
lijian6's avatar
lijian6 committed
1589

Chenggang Zhao's avatar
Chenggang Zhao committed
1590
1591
1592
        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA ranks
lijian6's avatar
lijian6 committed
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
            const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
            const auto sub_warp_id   = target_rank % kNumWarpsPerForwarder;
            auto send_buffer         = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            // auto sync_large_warp     = [=]() {
            //     if(kNumWarpsPerForwarder == 1) {
            //         syncwarp();
            //     } else {
            //         // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
            //         // __syncthreads();
            //         syncwarp();
            //     }
            // };
            auto sync_large_warp = [=](const int iter, const int mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1606
                if (kNumWarpsPerForwarder == 1) {
lijian6's avatar
lijian6 committed
1607
                    syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1608
                } else {
lijian6's avatar
lijian6 committed
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
                        // LDS index to store for sync
                        int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        //reset index in the LDS to avoid race condition due to warp scheduling
                        int reset_idx =         dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        // // if (lane_id==0)
                        // //     printf("rank %d dst_rdma_rank %d iter %d  warp_id %d  val %d\n", rank, dst_rdma_rank, iter, warp_id, sync_large_warp_counters[lds_dst_rdma_rank]);
                        auto start_time = wall_clock64();
                        if (lane_id == 0){
                            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
                        }
                        syncwarp();
                        //The while(...) loop polls the counter until all warps have arrived
                        if (lane_id == 0){
                            while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
                                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                                    printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
                                    trap();
                                }
lijian6's avatar
lijian6 committed
1627
1628
                            }
                        }
lijian6's avatar
lijian6 committed
1629
1630
1631
1632
1633
                        syncwarp();
                        if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
                            sync_large_warp_counters[reset_idx] = 0;
                        }
                        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1634
1635
                }
            };
lijian6's avatar
lijian6 committed
1636
            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
1637

lijian6's avatar
lijian6 committed
1638
1639
            // Advance to the corresponding NVL buffer (an address offset applied on top of the original pointer)
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
lijian6's avatar
lijian6 committed
1640
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
lijian6's avatar
lijian6 committed
1641
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
Chenggang Zhao's avatar
Chenggang Zhao committed
1642
1643
1644
1645
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1646
1647
1648
1649
1650
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
            // sync_forwarder_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1651
1652
1653

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1654
1655
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1656
1657
1658
1659
1660
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
lijian6's avatar
lijian6 committed
1661
            for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1662
                // Check destination queue emptiness, or wait a buffer to be released
lijian6's avatar
lijian6 committed
1663
                auto token_end_idx      = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
Chenggang Zhao's avatar
Chenggang Zhao committed
1664
                auto num_chunked_tokens = token_end_idx - token_start_idx;
lijian6's avatar
lijian6 committed
1665
                auto start_time         = wall_clock64();
lijian6's avatar
lijian6 committed
1666
1667
1668
1669
1670
1671
                while(sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));

                    if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
1672
1673
1674
                        break;

                    // Timeout check
lijian6's avatar
lijian6 committed
1675
1676
1677
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                                channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
1678
1679
1680
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1681
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1682
                sync_large_warp(token_start_idx, 0);
lijian6's avatar
lijian6 committed
1683

Chenggang Zhao's avatar
Chenggang Zhao committed
1684
                // Combine and write to the RDMA buffer
lijian6's avatar
lijian6 committed
1685
                for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1686
                    // Read expected head
lijian6's avatar
lijian6 committed
1687
                    EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1688
                    int expected_head = -1;
lijian6's avatar
lijian6 committed
1689
1690
                    if(lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1691
1692

                    // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1693
                    start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1694
1695
                    while(cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1696
1697

                        // Timeout check
lijian6's avatar
lijian6 committed
1698
1699
1700
                        if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                    channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1701
1702
1703
1704
1705
                            trap();
                        }
                    }

                    // Combine current token
lijian6's avatar
lijian6 committed
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                    expected_head, lane_id,
                                                                                    hidden_int4, num_topk,
                                                                                    reinterpret_cast<int4*>(shifted),
                                                                                    reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                    num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1716
1717

                    // Update head
lijian6's avatar
lijian6 committed
1718
1719
1720
1721
                    if(lane_id < NUM_MAX_NVL_PEERS) {
                        expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
                                        : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1722
                }
lijian6's avatar
lijian6 committed
1723
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1724
                sync_large_warp(token_start_idx, 1);
lijian6's avatar
lijian6 committed
1725

Chenggang Zhao's avatar
Chenggang Zhao committed
1726
                // Issue RDMA send
lijian6's avatar
lijian6 committed
1727
1728
                if(sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1729
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
1730
                        rocshmem::rocshmem_ctx_schar_put_nbi_wave(
lijian6's avatar
lijian6 committed
1731
1732
1733
1734
1735
1736
1737
1738
1739
                            ctx,
                            rdma_channel_data.recv_buffer(rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            rdma_channel_data.send_buffer(dst_rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            num_chunked_tokens * num_bytes_per_rdma_token,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

                        rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1740
1741
1742
1743
1744
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
lijian6's avatar
lijian6 committed
1745
                    syncwarp();
lijian6's avatar
lijian6 committed
1746
                    if(lane_id == 0) {
lijian6's avatar
lijian6 committed
1747
1748
1749
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1750
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1751
1752
1753
1754
                }
            }

            // Retired
lijian6's avatar
lijian6 committed
1755
            syncwarp();
lijian6's avatar
lijian6 committed
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
            if(lane_id == 0) {
                forwarder_retired[target_rank] = true;
            }
        } else if (warp_role == WarpRole::kRDMACoordinator) {
            // Coordinator
            // Sync shared memory status
            // sync_forwarder_smem();
            __syncthreads();
            constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
                // Retired
                if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for(int i = 0; i < kNumRDMARanks; ++i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for(int j = 0; j < num_warps_per_rdma_rank; ++j)
                            if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);

                        if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                        }
                    }
                }

                // Nanosleep and let other warps work
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
            }
        } else if(warp_role == WarpRole::kRDMAReceiver) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1795
1796
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1797
            EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lijian6's avatar
lijian6 committed
1798
1799
1800
1801
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1802
1803
1804

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1805
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1806
1807
1808

            // Iterate over all tokens and combine
            int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1809
            for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1810
                // Read expected head
lijian6's avatar
lijian6 committed
1811
                EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1812
                int expected_head = -1;
lijian6's avatar
lijian6 committed
1813
1814
1815
1816
                if(lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
                                        : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1817
1818
1819
                }

                // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1820
                auto start_time = wall_clock64();
Chenggang Zhao's avatar
Chenggang Zhao committed
1821
                while (cached_channel_tail_idx <= expected_head) {
lijian6's avatar
lijian6 committed
1822
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
Chenggang Zhao's avatar
Chenggang Zhao committed
1823
1824

                    // Timeout check
lijian6's avatar
lijian6 committed
1825
1826
1827
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                                channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1828
1829
1830
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1831
                syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1832
1833

                // Combine current token
lijian6's avatar
lijian6 committed
1834
1835
1836
1837
1838
1839
1840
1841
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                            expected_head, lane_id,
                                                                            hidden_int4, num_topk,
                                                                            combined_x + token_idx * hidden_int4,
                                                                            combined_topk_weights + token_idx * num_topk,
                                                                            num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1842
1843
1844
            }

            // Retired
lijian6's avatar
lijian6 committed
1845
            syncwarp();
lijian6's avatar
lijian6 committed
1846
1847
1848
1849
            if(lane_id == 0) {
                rdma_receiver_retired[target_rank] = true;
            }
        } else if(warp_role == WarpRole::kNVLCoordinator) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1850
1851
            // Coordinator
            // Sync shared memory status
lijian6's avatar
lijian6 committed
1852
1853
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1854
1855
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1856
            int last_rdma_head               = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
1857
            int last_nvl_head[kNumRDMARanks] = {0};
lijian6's avatar
lijian6 committed
1858
1859
            int dst_rdma_rank                = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
lijian6's avatar
lijian6 committed
1860
1861
1862
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1863
                // Retired
lijian6's avatar
lijian6 committed
1864
                if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
Chenggang Zhao's avatar
Chenggang Zhao committed
1865
1866
1867
                    break;

                // Find minimum head for RDMA ranks
lijian6's avatar
lijian6 committed
1868
                {
Chenggang Zhao's avatar
Chenggang Zhao committed
1869
                    int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1870
1871
1872
    #pragma unroll
                    for(int i = 0; i < kNumRDMAReceivers; ++i)
                        if(not rdma_receiver_retired[i])
lijian6's avatar
lijian6 committed
1873
                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
lijian6's avatar
lijian6 committed
1874
1875

                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lijian6's avatar
lijian6 committed
1876
1877
1878
1879
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

1880
1881
                        last_rdma_head = min_head;
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1882
1883
1884
                }

                // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1885
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1886
1887
1888
            }
        }
    }
lijian6's avatar
lijian6 committed
1889
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1890
1891
}

lijian6's avatar
lijian6 committed
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1902
    constexpr int kNumCombineForwarderWarps = 8;
lijian6's avatar
lijian6 committed
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921

#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                                        \
    {                                                                                              \
        auto combine_func =                                                                        \
            low_latency_mode                                                                       \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>           \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>;         \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights,       \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights,            \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1),        \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,           \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr,                    \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

lijian6's avatar
lijian6 committed
1922
    int num_rdma_ranks           = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1923
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
lijian6's avatar
lijian6 committed
1924
1925
    int num_forwarder_warps      = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
lijian6's avatar
lijian6 committed
1926
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
1927
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
lijian6's avatar
lijian6 committed
1928
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
lijian6's avatar
lijian6 committed
1929
    EP_HOST_ASSERT(type == HIP_R_16BF);
Chenggang Zhao's avatar
Chenggang Zhao committed
1930

lijian6's avatar
lijian6 committed
1931
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1932
1933
1934
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
1935

Chenggang Zhao's avatar
Chenggang Zhao committed
1936
1937
1938
} // namespace internode

} // namespace deep_ep
lijian6's avatar
lijian6 committed
1939

lijian6's avatar
lijian6 committed
1940
1941
1942
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
lijian6's avatar
lijian6 committed
1943
1944

#endif