internode.cu 105 KB
Newer Older
lijian6's avatar
lijian6 committed
1
#include "hip/hip_runtime.h"
Chenggang Zhao's avatar
Chenggang Zhao committed
2
#include "buffer.cuh"
lijian6's avatar
lijian6 committed
3
#include "configs.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
4
5
#include "launch.cuh"
#include "utils.cuh"
lijian6's avatar
lijian6 committed
6
7
8
9
10
11

#ifndef DISABLE_ROCSHMEM

#include <rocshmem/rocshmem.hpp>

// TODO: fix unroll warnings
lijian6's avatar
lijian6 committed
12
13
14
15
16
// #ifdef __clang__
// #pragma clang diagnostic push
// #pragma clang diagnostic ignored "-Wpass-failed"
// #pragma clang diagnostic ignored "-Wdeprecated-volatile"
// #endif // __clang__
Chenggang Zhao's avatar
Chenggang Zhao committed
17
18
19
20
21

namespace deep_ep {

namespace internode {

lijian6's avatar
lijian6 committed
22
extern rocshmem::rocshmem_team_t cpu_rdma_team;
Chenggang Zhao's avatar
Chenggang Zhao committed
23
24
25
26
27
28
29
30
31

struct SourceMeta {
    // `src_rdma_rank` stores the `rdma_rank` given to the encoding constructor.
    // `is_token_in_nvl_rank_bits` packs one flag per NVL peer: bit `i` holds
    // `is_token_in_nvl_ranks[i]`.
    int src_rdma_rank, is_token_in_nvl_rank_bits;

    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");

    __forceinline__ SourceMeta() = default;

    // Encode `NUM_MAX_NVL_PEERS` booleans into the low bits of `is_token_in_nvl_rank_bits`.
    // TODO: faster encoding
    __device__ __forceinline__ SourceMeta(int rdma_rank, const bool *is_token_in_nvl_ranks)
        : src_rdma_rank(rdma_rank), is_token_in_nvl_rank_bits(0) {
#pragma unroll
        for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
            is_token_in_nvl_rank_bits |= static_cast<int>(is_token_in_nvl_ranks[peer]) << peer;
    }

    // Decode the flag for NVL peer `nvl_rank` (expected in [0, NUM_MAX_NVL_PEERS)).
    __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const {
        return ((is_token_in_nvl_rank_bits >> nvl_rank) & 1) != 0;
    }
};

// Size in bytes of one encoded `SourceMeta` record, exposed for host-side buffer sizing.
int get_source_meta_bytes() {
    const int meta_bytes = static_cast<int>(sizeof(SourceMeta));
    return meta_bytes;
}

lijian6's avatar
lijian6 committed
49
50
51
52
53
54
55
56
__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token(int hidden_int4,
                                                                     int num_scales,
                                                                     int num_topk_idx,
                                                                     int num_topk_weights) {
    // Bytes one token occupies in the RDMA buffer. The sum covers the hidden payload
    // (`hidden_int4` int4s), a `SourceMeta`, `num_scales` floats, `num_topk_idx` ints and
    // `num_topk_weights` floats. It is rounded up to a multiple of `sizeof(int4)`.
    const size_t unpadded_bytes = hidden_int4 * sizeof(int4) + sizeof(SourceMeta) +
                                  num_scales * sizeof(float) + num_topk_idx * sizeof(int) +
                                  num_topk_weights * sizeof(float);
    const auto padded_bytes = ALIGN(unpadded_bytes, sizeof(int4));
    return static_cast<int>(padded_bytes);
}

lijian6's avatar
lijian6 committed
59
60
61
__host__ __device__ __forceinline__ std::pair<int, int>
get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                    int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) {
    // Return `int32_t` offset and count to clean.
    //
    // `.first` is the size of the RDMA token data region, expressed in `int`s. That region
    // is one padded token per receive slot, for every RDMA rank, times 2, for every SM.
    // The clean region starts right after it.
    // `.second` is the number of `int`s to zero: `NUM_MAX_NVL_PEERS * 2 + 4` per RDMA rank,
    // times 2, for every SM.
    //
    // The byte product is accumulated in 64 bits. With realistic hidden sizes, rank counts
    // and SM counts it can exceed INT32_MAX, and signed 32-bit overflow is undefined
    // behavior. Wrapping could also yield an offset that slips past the host-side
    // buffer-size assert. Computing in `int64_t` keeps the division exact. A too-large
    // configuration then yields an offset whose byte size exceeds the buffer, so the
    // caller's size check fails as it should.
    const int64_t token_bytes = static_cast<int64_t>(
        get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights));
    const int64_t data_region_bytes =
        token_bytes * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms;
    const int64_t clean_offset = data_region_bytes / static_cast<int64_t>(sizeof(int));
    const int64_t clean_count =
        static_cast<int64_t>(NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms;
    return {static_cast<int>(clean_offset), static_cast<int>(clean_count)};
}
lijian6's avatar
lijian6 committed
68
69
70
71
__host__ __device__ __forceinline__ std::pair<int, int>
get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
                   int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens,
                   int num_sms) {
    // Return the `int32_t` offset at which cleaning starts and the number of `int32_t`
    // words to clean.
    //
    // The offset is the NVL token data region (`num_nvl_recv_buffer_tokens` tokens per NVL
    // rank per SM) measured in `int`s. The per-token size here is NOT padded to
    // `sizeof(int4)`, unlike the RDMA layout. The count is
    // `num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms`.
    EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0,
                              "Invalid size of `SourceMeta`");

    // All terms are `size_t`, so the product below is evaluated in `size_t`.
    const size_t nvl_token_bytes = hidden_int4 * sizeof(int4) + num_scales * sizeof(float) +
                                   num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) +
                                   sizeof(SourceMeta);
    const size_t data_region_bytes =
        num_nvl_recv_buffer_tokens * nvl_token_bytes * num_nvl_ranks * num_sms;
    const int clean_count = num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms;

    return {data_region_bytes / sizeof(int), clean_count};
}

template <bool kLowLatencyMode>
__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
                                                       const int nvl_rank) {
    // Map a destination RDMA rank to the PE index passed to rocSHMEM.
    // Normal mode: `dst_rdma_rank` is returned unchanged.
    // Low-latency mode: the result is `dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank`, i.e.
    // the same local GPU index on the destination node.
    // NOTE(review): presumably the non-LL team is indexed by RDMA rank directly; confirm
    // against the rocSHMEM initialization.
    if (!kLowLatencyMode)
        return dst_rdma_rank;
    return dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank;
}

template <bool kLowLatencyMode>
__forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
    // Device-side RDMA barrier.
    // Low-latency mode: a barrier over `rdma_team` on the default rocSHMEM context.
    // Otherwise: `rocshmem_barrier_all()`.
    // NOTE: shmem_device_barrier_all() might be an issue as
    // it doesn't follow OpenSHMEM specification on ROCm
    if (kLowLatencyMode) {
        rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team);
        return;
    }
    rocshmem::rocshmem_barrier_all();
}

template <bool kLowLatencyMode, int kNumRDMARanks>
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int num_experts, const bool *is_token_in_rank, int num_tokens, int num_channels,
                int expert_alignment, const int rdma_clean_offset, const int rdma_num_int_clean,
                const int nvl_clean_offset, const int nvl_num_int_clean,
                int *rdma_channel_prefix_matrix, int *recv_rdma_rank_prefix_sum,
                int *gbl_channel_prefix_matrix, int *recv_gbl_rank_prefix_sum,
                void *rdma_buffer_ptr, void **buffer_ptrs, int **barrier_signal_ptrs, int rank,
                const rocshmem::rocshmem_team_t rdma_team) {
    // Exchanges token counts ahead of an internode dispatch and builds per-channel layouts.
    //
    // Block 0 (coordinator):
    //   - zeroes the RDMA and NVL clean regions;
    //   - sends each RDMA rank its slice of per-rank / per-expert / per-node counts with
    //     rocSHMEM puts;
    //   - forwards counts to local peers through the NVL buffers;
    //   - writes `recv_rdma_rank_prefix_sum` and `recv_gbl_rank_prefix_sum` (inclusive);
    //   - publishes totals through the `*_mapped` counters, spinning until each one
    //     reads -1 before overwriting it.
    // Block `1 + r`: for every channel, counts tokens headed to RDMA rank `r` (any GPU)
    // and to each of its NVL peers. It then turns those rows into inclusive prefix sums.
    // Device-checked preconditions: more than one warp, and `kNumRDMARanks`,
    // `num_rdma_experts`, `num_nvl_experts` and `NUM_MAX_NVL_PEERS` each <= blockDim.x.
    int block_id    = static_cast<int>(blockIdx.x);
    int tid         = static_cast<int>(threadIdx.x);
    int warp_idx    = tid / kWarpSize;
    int lane        = get_lane_id();
    int num_threads = static_cast<int>(blockDim.x);
    int num_warps   = num_threads / kWarpSize;

    int my_rdma_rank     = rank / NUM_MAX_NVL_PEERS;
    int my_nvl_rank      = rank % NUM_MAX_NVL_PEERS;
    int num_rdma_experts = num_experts / kNumRDMARanks;
    int num_nvl_experts  = num_rdma_experts / NUM_MAX_NVL_PEERS;

    if (block_id != 0) {
        // ---- Metadata blocks: block `1 + r` handles destination RDMA rank `r` ----
        int target_rdma_rank = block_id - 1;
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t),
                                  "Invalid number of NVL peers");

        // Warps stride over channels; lanes stride over the channel's tokens.
        for (int channel = warp_idx; channel < num_channels; channel += num_warps) {
            int begin_token, end_token;
            get_channel_task_range(num_tokens, num_channels, channel, begin_token, end_token);

            // Per-lane partial counts: tokens reaching any GPU of the target node, and
            // tokens reaching each individual GPU of that node.
            int node_token_count = 0, gpu_token_count[NUM_MAX_NVL_PEERS] = {0};
            for (int64_t token = begin_token + lane; token < end_token; token += kWarpSize) {
                // The node's NUM_MAX_NVL_PEERS boolean flags are read as one 64-bit word.
                auto flags_word = *reinterpret_cast<const uint64_t *>(
                    is_token_in_rank + token * num_ranks + target_rdma_rank * NUM_MAX_NVL_PEERS);
                auto flags = reinterpret_cast<const bool *>(&flags_word);
#pragma unroll
                for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
                    gpu_token_count[peer] += flags[peer];
                node_token_count += (flags_word != 0);
            }

            // Combine the per-lane partials across the warp.
            node_token_count = warp_reduce_sum(node_token_count);
#pragma unroll
            for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
                gpu_token_count[peer] = warp_reduce_sum(gpu_token_count[peer]);

            // Lane 0 stores the raw per-channel counts; prefix sums are taken below.
            if (lane == 0) {
#pragma unroll
                for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer) {
                    int global_rank = target_rdma_rank * NUM_MAX_NVL_PEERS + peer;
                    gbl_channel_prefix_matrix[global_rank * num_channels + channel] =
                        gpu_token_count[peer];
                }
                rdma_channel_prefix_matrix[target_rdma_rank * num_channels + channel] =
                    node_token_count;
            }
        }

        // In-place inclusive prefix sums along the channel axis.
        __syncthreads();
        if (tid == 0) {
            int *row = rdma_channel_prefix_matrix + target_rdma_rank * num_channels;
            for (int c = 1; c < num_channels; ++c)
                row[c] += row[c - 1];
        }

        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
        if (tid < NUM_MAX_NVL_PEERS) {
            int *row = gbl_channel_prefix_matrix +
                       (target_rdma_rank * NUM_MAX_NVL_PEERS + tid) * num_channels;
            for (int c = 1; c < num_channels; ++c)
                row[c] += row[c - 1];
        }
        return;
    }

    // ---- Block 0: count exchange ----
    // Global barrier. Thread `kWarpSize` (lane 0 of warp 1) runs the RDMA-team barrier;
    // every thread then calls the intra-node `barrier_block`.
    EP_DEVICE_ASSERT(num_warps > 1);
    EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
    if (tid == kWarpSize)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, my_nvl_rank);
    __syncthreads();

    // One symmetric slot per RDMA rank, laid out as
    // [NVL-rank counts | that node's expert counts | node total].
    const int rdma_total_idx = NUM_MAX_NVL_PEERS + num_rdma_experts;
    // Capture the int view of the RDMA base BEFORE building the SymBuffer.
    // NOTE(review): SymBuffer receives `rdma_buffer_ptr` and may advance it; keep this order.
    int *rdma_ints = reinterpret_cast<int *>(rdma_buffer_ptr);
    auto rdma_recv_num_tokens_mixed =
        SymBuffer<int>(rdma_buffer_ptr, rdma_total_idx + 1, kNumRDMARanks);

    // Zero the RDMA region used by the later data dispatch.
    EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <=
                              rdma_clean_offset * sizeof(int));
    for (int idx = tid; idx < rdma_num_int_clean; idx += num_threads)
        rdma_ints[rdma_clean_offset + idx] = 0;

    // Fill the per-destination send slots.
    for (int r = tid; r < num_ranks; r += num_threads)
        rdma_recv_num_tokens_mixed.send_buffer(r / NUM_MAX_NVL_PEERS)[r % NUM_MAX_NVL_PEERS] =
            num_tokens_per_rank[r];
    for (int e = tid; e < num_experts; e += num_threads)
        rdma_recv_num_tokens_mixed.send_buffer(e / num_rdma_experts)
            [NUM_MAX_NVL_PEERS + e % num_rdma_experts] = num_tokens_per_expert[e];
    if (tid < kNumRDMARanks)
        rdma_recv_num_tokens_mixed.send_buffer(tid)[rdma_total_idx] =
            num_tokens_per_rdma_rank[tid];
    __syncthreads();

    // Push slot `tid` into our own slot of RDMA rank `tid`, then barrier over RDMA.
    // TODO: more light fence or barrier or signaling
    // TODO: overlap EP barrier and NVL cleaning
    if (tid < kNumRDMARanks) {
        rocshmem::rocshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(my_rdma_rank),
                                       rdma_recv_num_tokens_mixed.send_buffer(tid),
                                       rdma_total_idx + 1,
                                       translate_dst_rdma_rank<kLowLatencyMode>(tid, my_nvl_rank));
    }
    __syncthreads();
    if (tid == 0)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    __syncthreads();

    // NVL views. Threads below NUM_MAX_NVL_PEERS address peer `tid`'s buffer; the rest get
    // nullptr. The buffer constructors below take these pointers in this exact order.
    void *nvl_peer_buffer  = tid < NUM_MAX_NVL_PEERS ? buffer_ptrs[tid] : nullptr;
    void *nvl_local_buffer = buffer_ptrs[my_nvl_rank];
    auto nvl_reduced_num_tokens_per_expert =
        Buffer<int>(nvl_local_buffer, num_rdma_experts).advance_also(nvl_peer_buffer);
    auto nvl_send_num_tokens_per_rank =
        AsymBuffer<int>(nvl_peer_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
    auto nvl_send_num_tokens_per_expert =
        AsymBuffer<int>(nvl_peer_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
    auto nvl_recv_tokens_per_rank =
        AsymBuffer<int>(nvl_local_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS);
    auto nvl_recv_tokens_per_expert =
        AsymBuffer<int>(nvl_local_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);

    // Zero the NVL region used by the later data dispatch.
    int *nvl_ints = reinterpret_cast<int *>(buffer_ptrs[my_nvl_rank]);
    EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes +
                                  nvl_send_num_tokens_per_rank.total_bytes +
                                  nvl_send_num_tokens_per_expert.total_bytes <=
                              nvl_clean_offset * sizeof(int));
    for (int idx = tid; idx < nvl_num_int_clean; idx += num_threads)
        nvl_ints[nvl_clean_offset + idx] = 0;

    // Sum each of this node's experts over all RDMA sources.
    // TODO: may use NVSHMEM reduction
    EP_DEVICE_ASSERT(num_rdma_experts <= num_threads);
    if (tid < num_rdma_experts) {
        int expert_total = 0;
#pragma unroll
        for (int r = 0; r < kNumRDMARanks; ++r)
            expert_total += rdma_recv_num_tokens_mixed.recv_buffer(r)[NUM_MAX_NVL_PEERS + tid];
        nvl_reduced_num_tokens_per_expert[tid] = expert_total;
    }
    __syncthreads();

    // Inclusive prefix over RDMA sources, then publish the total once the host reset it to -1.
    if (tid == 0) {
        int running = 0;
#pragma unroll
        for (int r = 0; r < kNumRDMARanks; ++r) {
            running += rdma_recv_num_tokens_mixed.recv_buffer(r)[rdma_total_idx];
            recv_rdma_rank_prefix_sum[r] = running;
        }
        while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1) {
        }
        *moe_recv_rdma_counter_mapped = running;
    }

    // Forward per-rank and per-expert counts to NVL peer `tid`.
    EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads);
    if (tid < NUM_MAX_NVL_PEERS) {
#pragma unroll
        for (int r = 0; r < kNumRDMARanks; ++r)
            nvl_send_num_tokens_per_rank.buffer(my_nvl_rank)[r] =
                rdma_recv_num_tokens_mixed.recv_buffer(r)[tid];
        for (int e = 0; e < num_nvl_experts; ++e)
            nvl_send_num_tokens_per_expert.buffer(my_nvl_rank)[e] =
                nvl_reduced_num_tokens_per_expert[tid * num_nvl_experts + e];
    }
    memory_fence();
    __syncthreads();
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, my_nvl_rank);
    __syncthreads();

    // Reduce what the NVL peers sent us.
    EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
    if (tid == 0) {
        int running = 0;
        for (int r = 0; r < num_ranks; ++r) {
            int src_rdma = r / NUM_MAX_NVL_PEERS;
            int src_nvl  = r % NUM_MAX_NVL_PEERS;
            running += nvl_recv_tokens_per_rank.buffer(src_nvl)[src_rdma];
            recv_gbl_rank_prefix_sum[r] = running;
        }
        while (ld_volatile_global(moe_recv_counter_mapped) != -1) {
        }
        *moe_recv_counter_mapped = running;
    }
    if (tid < num_nvl_experts) {
        int expert_total = 0;
#pragma unroll
        for (int peer = 0; peer < NUM_MAX_NVL_PEERS; ++peer)
            expert_total += nvl_recv_tokens_per_expert.buffer(peer)[tid];
        // Round up to a multiple of `expert_alignment`.
        expert_total = (expert_total + expert_alignment - 1) / expert_alignment * expert_alignment;
        while (ld_volatile_global(moe_recv_expert_counter_mapped + tid) != -1) {
        }
        moe_recv_expert_counter_mapped[tid] = expert_total;
    }

    // Final barrier: RDMA (thread kWarpSize) then intra-node (all threads).
    __syncthreads();
    if (tid == kWarpSize)
        nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, my_nvl_rank);
}

lijian6's avatar
lijian6 committed
327
328
329
330
331
332
333
334
335
336
337
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_rdma_rank, int *moe_recv_rdma_counter_mapped,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int num_experts, const bool *is_token_in_rank, int num_tokens,
                     int num_channels, int hidden_int4, int num_scales, int num_topk,
                     int expert_alignment, int *rdma_channel_prefix_matrix,
                     int *recv_rdma_rank_prefix_sum, int *gbl_channel_prefix_matrix,
                     int *recv_gbl_rank_prefix_sum, void *rdma_buffer_ptr,
                     int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                     int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                     hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
                     bool low_latency_mode) {
    // Host launcher for the `notify_dispatch` kernel: `1 + num_rdma_ranks` blocks of
    // `kNumThreads` threads on `stream`.
    constexpr int kNumThreads    = 256;
    const auto    num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Regions the kernel zeroes for the following dispatch. `.first` is an offset and
    // `.second` a length, both in `int`s. Top-k indices and weights are both `num_topk` wide.
    const auto rdma_clean_meta = get_rdma_clean_meta(
        hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks,
        num_max_rdma_chunked_recv_tokens, num_channels);
    const auto nvl_clean_meta = get_nvl_clean_meta(
        hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS,
        num_max_nvl_chunked_recv_tokens, num_channels);

    // The cleaned regions must fit inside the registered buffers...
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes);
    // ...and the buffer sizes must stay below INT_MAX (the kernel takes `int` offsets).
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());

    // Launch the <low_latency_mode, kRanks> instantiation. The trailing `break` belongs to
    // the case list that SWITCH_RDMA_RANKS presumably expands to.
#define NOTIFY_DISPATCH_LAUNCH_CASE(kRanks) \
    { \
        auto notify_dispatch_func = notify_dispatch<false, kRanks>; \
        if (low_latency_mode) \
            notify_dispatch_func = notify_dispatch<true, kRanks>; \
        LAUNCH_KERNEL_NON_COOPERATIVE( \
            &cfg, notify_dispatch_func, num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \
            num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, num_tokens_per_expert, \
            moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, num_tokens, \
            num_channels, expert_alignment, rdma_clean_meta.first, rdma_clean_meta.second, \
            nvl_clean_meta.first, nvl_clean_meta.second, rdma_channel_prefix_matrix, \
            recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
            rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, cpu_rdma_team); \
    } \
    break

    SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream);
    SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Number of distinct RDMA ranks a single token is sent to: `num_rdma_ranks`,
// capped at 8.
constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) {
    return (num_rdma_ranks > 8) ? 8 : num_rdma_ranks;
}

lijian6's avatar
lijian6 committed
382
383
384
385

template <bool kLowLatencyMode,
          int kNumRDMARanks,
          bool kCachedMode,
lijian6's avatar
lijian6 committed
386
387
          int kNumDispatchRDMASenderWarps,
          int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks)>
lijian6's avatar
lijian6 committed
388
389
390
391
392
393
394
395
396
397
398
399
400
__global__ void __launch_bounds__(((1 + NUM_MAX_NVL_PEERS) * kWarpSize), 1)
dispatch(int4 *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
        SourceMeta *recv_src_meta, const int4 *x, const float *x_scales,
        const int64_t *topk_idx, const float *topk_weights, int *send_rdma_head,
        int *send_nvl_head, int *recv_rdma_channel_prefix_matrix,
        int *recv_gbl_channel_prefix_matrix, const int *rdma_channel_prefix_matrix,
        const int *recv_rdma_rank_prefix_sum, const int *gbl_channel_prefix_matrix,
        const int *recv_gbl_rank_prefix_sum, const bool *is_token_in_rank, int num_tokens,
        int hidden_int4, int num_scales, int num_topk, int num_experts, int scale_token_stride,
        int scale_hidden_stride, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
        int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
        int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
        int num_ranks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
401
    enum class WarpRole {
lijian6's avatar
lijian6 committed
402
403
404
405
406
        kRDMASender,            // 从x写入到RDMA发送缓存
        kRDMASenderCoordinator, // 从RDMA发送缓存写入到远端rdma_rank接收缓存
        kRDMAAndNVLForwarder,   // 从RDMA接收缓存转写到ipc nvl缓存
        kForwarderCoordinator,  // 向远端RDMA确认接收
        kNVLReceivers           // 从nvl缓存写入到recv_x
Chenggang Zhao's avatar
Chenggang Zhao committed
407
408
    };

lijian6's avatar
lijian6 committed
409
410
411
412
413
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
lijian6's avatar
lijian6 committed
414
415
416
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
Chenggang Zhao's avatar
Chenggang Zhao committed
417
418
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;

lijian6's avatar
lijian6 committed
419
420
421
    EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers");
    EP_DEVICE_ASSERT(num_warps == 1 + NUM_MAX_NVL_PEERS);

Chenggang Zhao's avatar
Chenggang Zhao committed
422
    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
lijian6's avatar
lijian6 committed
423
424
425
426
427
428
429
430
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASender, -1};
            } else if(warp_id == kNumDispatchRDMASenderWarps) {
                return {WarpRole::kRDMASenderCoordinator, -1};
            }
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            if(warp_id < NUM_MAX_NVL_PEERS) {
Chenggang Zhao's avatar
Chenggang Zhao committed
431
432
433
434
435
                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
            } else {
                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
            }
        } else {
lijian6's avatar
lijian6 committed
436
            return {WarpRole::kNVLReceivers, (warp_id + channel_id + 1) % NUM_MAX_NVL_PEERS};
Chenggang Zhao's avatar
Chenggang Zhao committed
437
438
        }
    }();
lijian6's avatar
lijian6 committed
439

lijian6's avatar
lijian6 committed
440
    auto warp_role = role_meta.first;
Chenggang Zhao's avatar
Chenggang Zhao committed
441
442
443
    auto target_rank = role_meta.second; // Not applicable for RDMA senders

    // RDMA symmetric layout
lijian6's avatar
lijian6 committed
444
445
446
447
448
449
    auto hidden_bytes             = hidden_int4 * sizeof(int4);
    auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk);
    auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
450
451

    // NVL buffer layouts
lijian6's avatar
lijian6 committed
452
    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
Chenggang Zhao's avatar
Chenggang Zhao committed
453
    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
lijian6's avatar
lijian6 committed
454
    int rs_wr_rank = 0, ws_rr_rank = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
455
    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
lijian6's avatar
lijian6 committed
456
        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
457
    if (warp_role == WarpRole::kNVLReceivers)
lijian6's avatar
lijian6 committed
458
        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
459
460

    // Allocate buffers
lijian6's avatar
lijian6 committed
461
462
463
464
465
466
467
468
469
    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
470
471

    // RDMA sender warp synchronization
lijian6's avatar
lijian6 committed
472
473
474
    __shared__ volatile int rdma_send_next_token_idx;
    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
475

lijian6's avatar
lijian6 committed
476
477
    // NVL and RDMA coordinate Forward warp synchronization
    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
478
    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
lijian6's avatar
lijian6 committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499

    // Place the main logic of your kernel here, using the parameters above.
    if(warp_role == WarpRole::kRDMASender) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先获取当前通道的任务范围,然后清理共享内存,接着计算并发送本通道中的令牌数量。
        然后,它遍历所有的令牌,读取每个令牌的RDMA秩的存在性,获取顺序锁,计算下一个尾部位置,存储RDMA头部,更新最后一个令牌尾部,释放顺序锁,并广播尾部位置。
        最后,它复制相关的数据到对称发送缓冲区。

        kRDMASender主要目的是将发送信息x, x_scale,source_meta, topk_idx, topk_weight等信息填充进入rdma发送缓存,
        期间要同步warp直接对token的依序操作,以及和kForwarderCoordinator, kRDMASenderCoordinator内存同步。
        同时在复制操作时, 使用ld.global.nc.L1::no_allocate.L2::256B, st.global.L1::no_allocate减少L1/L2缓存使用。
        */
        // 获取任务范围
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // 清理共享内存
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA秩数量");
        if(warp_id == 0 && lane_id == 0) {
            rdma_send_next_token_idx = token_start_idx;
lijian6's avatar
lijian6 committed
500
        }
lijian6's avatar
lijian6 committed
501
502
503
        if(warp_id == 0 && lane_id < kNumRDMARanks) {
            rdma_send_channel_tail[lane_id]      = 0;
            rdma_send_channel_next_tail[lane_id] = 0;
lijian6's avatar
lijian6 committed
504
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
505

lijian6's avatar
lijian6 committed
506
507
508
509
510
511
        // 发送本通道中的令牌数量,通过 `-value - 1` 表示
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= kWarpSize, "无效的NVL对等体数量");
        // 对于每个目标RDMA秩,以warp为单位进行迭代。计算发送缓冲区的值,并存储在rdma_channel_meta.send_buffer中
        // 用于填充rdma_channel_meta.send_buffer本节点发送到远端rank, rdma_rank的起始index和结束index
        for(int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
            auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
512
            if (lane_id < NUM_MAX_NVL_PEERS) {
lijian6's avatar
lijian6 committed
513
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
514
            } else if (lane_id < NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
515
                dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
516
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2) {
lijian6's avatar
lijian6 committed
517
                dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
518
            } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) {
lijian6's avatar
lijian6 committed
519
                dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
520
            }
lijian6's avatar
lijian6 committed
521

lijian6's avatar
lijian6 committed
522
523
524
            syncwarp();
            if (dst_rdma_rank != rdma_rank) {
                rocshmem::rocshmem_ctx_int_put_nbi_wave(
lijian6's avatar
lijian6 committed
525
526
527
                ctx, rdma_channel_meta.recv_buffer(rdma_rank),
                rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2,
                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
528
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
529
        }
lijian6's avatar
lijian6 committed
530
        rocshmem::rocshmem_ctx_quiet(ctx);
lijian6's avatar
lijian6 committed
531
532
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
533

lijian6's avatar
lijian6 committed
534
        // 遍历令牌并复制到缓冲区
Chenggang Zhao's avatar
Chenggang Zhao committed
535
        int64_t token_idx;
lijian6's avatar
lijian6 committed
536
537
538
539
        int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1;
        auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
        for(token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
            // 读取RDMA秩的存在性
Chenggang Zhao's avatar
Chenggang Zhao committed
540
            uint64_t is_token_in_rank_uint64 = 0;
lijian6's avatar
lijian6 committed
541
542
543
            if(lane_id < kNumRDMARanks) {
                is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
544

lijian6's avatar
lijian6 committed
545
546
547
548
            // 获得处理数据的自旋锁,获得锁后才会处理一些数据信息
            while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
                // 等待
            }
lijian6's avatar
lijian6 committed
549
            syncwarp();
550

lijian6's avatar
lijian6 committed
551
            // 获取下一个尾部位置
lijian6's avatar
lijian6 committed
552
            int rdma_tail_idx = -1;
lijian6's avatar
lijian6 committed
553
            if(is_token_in_rank_uint64 != 0) {
lijian6's avatar
lijian6 committed
554
                rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++;
lijian6's avatar
lijian6 committed
555
556
557
558
559

                // 与kForwarderCoordinator相互配合,调节发送数据的频率
                while(rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
560
            }
lijian6's avatar
lijian6 committed
561
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
562

lijian6's avatar
lijian6 committed
563
564
            // 存储RDMA头部以供合并
            if(lane_id < kNumRDMARanks && !kCachedMode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
565
                send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
lijian6's avatar
lijian6 committed
566
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
567

lijian6's avatar
lijian6 committed
568
569
570
571
            // 更新最后一个令牌尾部
            if(last_rdma_tail_idx >= 0) {
                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
            }
lijian6's avatar
lijian6 committed
572
573
            last_rdma_tail_idx = rdma_tail_idx;

lijian6's avatar
lijian6 committed
574
575
576
577
            // 释放顺序锁
            if(lane_id == 0) {
                rdma_send_next_token_idx += 1;
            }
lijian6's avatar
lijian6 committed
578

lijian6's avatar
lijian6 committed
579
            // 广播尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
580
            SourceMeta src_meta;
lijian6's avatar
lijian6 committed
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
            int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
            void* dst_send_buffers[kNumTopkRDMARanks];
            /*
            该for循环主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作
            */
            #pragma unroll
            for(int i = 0, slot_idx; i < kNumRDMARanks; ++i) {
                // 使用__shfl_sync函数在warp内同步并广播rdma_tail_idx的值
                if((slot_idx = shfl_sync(rdma_tail_idx, i)) >= 0) {
                    // warp 所有线程参与,rdma_tail_idx默认为-1, 只有对应rdma rank需要发送时, rdma_tail_idx才会>=0
                    //  计算slot_idx在接收缓冲区中的位置
                    slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens;

                    // 存储当前RDMA秩到topk_ranks数组中
                    topk_ranks[num_topk_ranks] = i;

                    // 广播is_token_in_rank_uint64的值到所有线程,并解释为布尔数组
lijian6's avatar
lijian6 committed
598
                    auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i);
lijian6's avatar
lijian6 committed
599
600
601
602
                    auto recv_is_token_in_rank_values = reinterpret_cast<const bool*>(&recv_is_token_in_rank_uint64);

                    // 如果当前lane_id等于num_topk_ranks,则更新src_meta
                    if(lane_id == num_topk_ranks) {
lijian6's avatar
lijian6 committed
603
                        src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values);
lijian6's avatar
lijian6 committed
604
605
606
607
608
                    }

                    // 计算目标发送缓冲区的地址,并存储在dst_send_buffers数组中
                    // 获取到发送地址, num_topk_ranks-1 是需要发送的ranks数
                    dst_send_buffers[num_topk_ranks++] = reinterpret_cast<uint8_t*>(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token;
lijian6's avatar
lijian6 committed
609
                }
lijian6's avatar
lijian6 committed
610
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
611
612
            EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks);

lijian6's avatar
lijian6 committed
613
614
615
616
617
618
            // 复制 `x` 到对称发送缓冲区
            auto st_broadcast = [=](const int key, const int4& value) {
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<int4*>(dst_send_buffers[j]) + key, value);
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
619
            };
lijian6's avatar
lijian6 committed
620
621
622
623
624
            UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast);
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<int4*>(dst_send_buffers[i]) + hidden_int4;
            }
lijian6's avatar
lijian6 committed
625

lijian6's avatar
lijian6 committed
626
627
628
629
630
631
632
633
            // 复制源元数据到对称发送缓冲区
            if(lane_id < num_topk_ranks) {
                st_na_global(reinterpret_cast<SourceMeta*>(dst_send_buffers[lane_id]), src_meta);
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<SourceMeta*>(dst_send_buffers[i]) + 1;
            }
lijian6's avatar
lijian6 committed
634

lijian6's avatar
lijian6 committed
635
636
637
            // 复制 `x_scales` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_scales; i += kWarpSize) {
lijian6's avatar
lijian6 committed
638
                auto value = ld_nc_global(x_scales + token_idx * num_scales + i);
lijian6's avatar
lijian6 committed
639
640
641
642
643
644
645
646
647
648
649

                // auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
                // auto value = ld_nc_global(x_scales + offset);
#pragma unroll
                for(int j = 0; j < num_topk_ranks; ++j) {
                    st_na_global(reinterpret_cast<float*>(dst_send_buffers[j]) + i, value);
                }
            }
#pragma unroll
            for(int i = 0; i < num_topk_ranks; ++i) {
                dst_send_buffers[i] = reinterpret_cast<float*>(dst_send_buffers[i]) + num_scales;
lijian6's avatar
lijian6 committed
650
            }
651

lijian6's avatar
lijian6 committed
652
653
654
            // 复制 `topk_idx` 和 `topk_weights` 到对称发送缓冲区
#pragma unroll
            for(int i = lane_id; i < num_topk * num_topk_ranks; i += kWarpSize) {
Chenggang Zhao's avatar
Chenggang Zhao committed
655
                auto rank_idx = i / num_topk, copy_idx = i % num_topk;
lijian6's avatar
lijian6 committed
656
                auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
657
                auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
lijian6's avatar
lijian6 committed
658
659
                st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
660
            }
lijian6's avatar
lijian6 committed
661
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
662

lijian6's avatar
lijian6 committed
663
664
665
666
667
668
        // 结尾部分
        // 获取顺序锁
        while(lane_id == 0 && rdma_send_next_token_idx != token_idx) {
            // 等待
        }

lijian6's avatar
lijian6 committed
669
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
670

lijian6's avatar
lijian6 committed
671
672
673
674
        // 更新最后一个令牌尾部
        if(last_rdma_tail_idx >= 0) {
            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
675

lijian6's avatar
lijian6 committed
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        // 释放顺序锁
        if(lane_id == 0) {
            rdma_send_next_token_idx += 1;
        }
    } else if(warp_role == WarpRole::kRDMASenderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调多个线程之间的RDMA发送操作。
        它首先计算每个RDMA秩需要发送的令牌数,然后在所有RDMA秩之间循环,检查是否有令牌需要发送。
        如果有,它将计算本次需要发出的令牌数,并发出相应的RDMA发送请求。
        最后,它更新相关的尾部位置,以便下次循环时可以正确地计算需要发送的令牌数。

        kRDMASenderCoordinator使用了同sm内存一致性(ld.acquire.cta.s32),
        nvshmem内存一致性(nvshmem_fence)和原子操作(nvshmemx_signal_op),减少硬同步,提升整体效率。
        */
        if(warp_id > kNumDispatchRDMASenderWarps) {
            return;
        }
        // 确保最大接收令牌数可以被最大发送令牌数整除,以避免缓冲区分割问题
        EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
695

lijian6's avatar
lijian6 committed
696
697
698
        // 同步共享内存,确保所有线程在继续之前都达到了这一点
        // sync_rdma_sender_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
699

lijian6's avatar
lijian6 committed
700
        // 计算当前通道需要发送的令牌数
Chenggang Zhao's avatar
Chenggang Zhao committed
701
        int num_tokens_to_send = 0;
lijian6's avatar
lijian6 committed
702
        if(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
703
            num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id];
lijian6's avatar
lijian6 committed
704
705
            if(channel_id > 0)
                num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
706
707
        }

lijian6's avatar
lijian6 committed
708
        // 记录上次发出的尾部位置
Chenggang Zhao's avatar
Chenggang Zhao committed
709
        int last_issued_tail = 0;
lijian6's avatar
lijian6 committed
710
711
712
713
714
715
716
        // 当有任何RDMA秩需要发送令牌时,继续循环
        while(__any_sync(kFullWarpMask, num_tokens_to_send > 0)) {
            for(int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) {
                // 计算目标RDMA秩
                int dst_rdma_rank = (i + channel_id) % kNumRDMARanks;

                // 获取同步后的需要发送的令牌数
lijian6's avatar
lijian6 committed
717
                synced_num_tokens_to_send = shfl_sync(num_tokens_to_send, dst_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
718

lijian6's avatar
lijian6 committed
719
720
721
722
                if(synced_num_tokens_to_send == 0)
                    continue; // 如果没有令牌需要发送,则跳过

                // 读取进度
lijian6's avatar
lijian6 committed
723
                auto synced_last_issued_tail = shfl_sync(last_issued_tail, dst_rdma_rank);
lijian6's avatar
lijian6 committed
724
725
726
727
728
                auto processed_tail          = ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank));
                auto num_tokens_processed    = processed_tail - synced_last_issued_tail;

                // 如果处理的令牌数不等于需要发送的令牌数,并且处理的令牌数小于最大发送令牌数,则跳过
                if(num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
729
730
                    continue;

lijian6's avatar
lijian6 committed
731
732
733
734
735
736
                // 计算本次需要发出的令牌数
                auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens);
                EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 && num_tokens_to_issue <= synced_num_tokens_to_send);

                // 发出RDMA发送请求
                if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
737
                    auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
738
                    EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
lijian6's avatar
lijian6 committed
739
740
741
742
743
744
745
746
747
                    rocshmem::rocshmem_ctx_schar_put_nbi_wave(
                        ctx,
                        rdma_channel_data.recv_buffer(rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        rdma_channel_data.send_buffer(dst_rdma_rank) +
                            dst_slot_idx * num_bytes_per_rdma_token,
                        num_bytes_per_rdma_token * num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                    rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
748
                } else {
lijian6's avatar
lijian6 committed
749
                    // 对于本地RDMA秩,使用较轻的内存屏障
Chenggang Zhao's avatar
Chenggang Zhao committed
750
751
752
                    memory_fence();
                }

lijian6's avatar
lijian6 committed
753
                // 更新尾部位置
lijian6's avatar
lijian6 committed
754
                syncwarp();
lijian6's avatar
lijian6 committed
755
                if(lane_id == dst_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
756
757
                    last_issued_tail += num_tokens_to_issue;
                    num_tokens_to_send -= num_tokens_to_issue;
lijian6's avatar
lijian6 committed
758
                    // 更新远端rdma 己方已发送的token数,用于做发送信息同步。用于与kRDMAAndNVLForwarder互相通信
lijian6's avatar
lijian6 committed
759
760
761
                    rocshmem::rocshmem_ctx_ulong_atomic_add(
                        ctx, rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
Chenggang Zhao's avatar
Chenggang Zhao committed
762
763
                }
            }
lijian6's avatar
lijian6 committed
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
        } // while(__any(num_tokens_to_send > 0))
    } else if(warp_role == WarpRole::kRDMAAndNVLForwarder) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调从RDMA消费者到NVL生产者的转发操作。
        它首先计算目标NVL秩和目标秩,然后等待相关的计数器到达。
        接着,它检查目标队列是否为空,或者等待一个缓冲区被释放。
        然后,它找到下一个源RDMA秩,并遍历RDMA缓冲区中的每一个令牌,复制相关的数据到NVL缓冲区。
        最后,它同步头部和尾部索引,并标记通道为退役状态。
        */
        // RDMA消费者和NVL生产者
        const auto dst_nvl_rank          = target_rank;                                       // 目标NVL秩
        const auto dst_rank              = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;      // 目标秩
        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);              // 目标秩专家开始
        const auto dst_rank_expert_end   = dst_rank_expert_begin + (num_experts / num_ranks); // 目标秩专家结束

        // 等待计数器到达
Chenggang Zhao's avatar
Chenggang Zhao committed
780
        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
lijian6's avatar
lijian6 committed
781
782
        EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
783
784
785
786
787
788
789
790
791
        if(lane_id < kNumRDMARanks) {
            while(true) {
                // 对应于kRDMASender中的数据写入
                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);                     // 是nvl节点的起始地址
                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); // nvl节点的结束地址
                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);            // 本rdma节点的起始地址
                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);        // 本节点的结束地址
                if(meta_0 < 0 && meta_1 < 0 && meta_2 < 0 && meta_3 < 0) {
                    // 通知NVL秩
Chenggang Zhao's avatar
Chenggang Zhao committed
792
                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
lijian6's avatar
lijian6 committed
793
794
795
                    EP_DEVICE_ASSERT(start_sum >= 0 && end_sum >= 0 && end_sum >= start_sum);

                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
796
797
                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);

lijian6's avatar
lijian6 committed
798
799
                    // 保存从RDMA通道接收的令牌计数
                    src_rdma_channel_prefix = -meta_2 - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
800
                    auto src_rdma_channel_prefix_1 = -meta_3 - 1;
lijian6's avatar
lijian6 committed
801
802
803
804
805
                    num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; // 是远端 rdma_rank 会发送给当前节点的token数量
                    if(!kCachedMode)
                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;

                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; // 对应的远端 rdma_rank 的起始index, 存在线程0之中
Chenggang Zhao's avatar
Chenggang Zhao committed
806
807
808
809
                    EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
                    break;
                }

lijian6's avatar
lijian6 committed
810
811
812
813
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
Chenggang Zhao's avatar
Chenggang Zhao committed
814
815
816
817
                    trap();
                }
            }
        }
lijian6's avatar
lijian6 committed
818
        syncwarp();
lijian6's avatar
lijian6 committed
819
820

        // 移动缓存的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
821
822
        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;

lijian6's avatar
lijian6 committed
823
824
825
        // 等待共享内存被清理
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
826

lijian6's avatar
lijian6 committed
827
828
829
830
        // 开始准备处理接受数据,直到所有的数据接受完成。
        // 转发从RDMA缓冲区的令牌
        // 注意:总是从本地秩开始
        int src_rdma_rank = sm_id % kNumRDMARanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
831
832
        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
lijian6's avatar
lijian6 committed
833
834
        while(__any_sync(kFullWarpMask, num_tokens_to_recv_from_rdma > 0)) {
            // 检查nvl目标队列是否为空,或者等待一个缓冲区被释放
lijian6's avatar
lijian6 committed
835
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
836
837
838

            // 用于给kNVLReceivers进行互动,控制数据的传输速度
            while(lane_id == 0) {
lijian6's avatar
lijian6 committed
839
                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
lijian6's avatar
lijian6 committed
840
                if(num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
841
                    break;
lijian6's avatar
lijian6 committed
842
                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
843

lijian6's avatar
lijian6 committed
844
845
846
847
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
848
849
850
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
851
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
852

lijian6's avatar
lijian6 committed
853
            // 找到下一个源RDMA秩(轮询)
lijian6's avatar
lijian6 committed
854
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
855
            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
856
                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
lijian6's avatar
lijian6 committed
857
858
859
860
861
                if(shfl_sync(num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
                    if(lane_id == src_rdma_rank && cached_rdma_channel_head == cached_rdma_channel_tail)
                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));

                    if(shfl_sync(cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) {
Chenggang Zhao's avatar
Chenggang Zhao committed
862
                        break;
lijian6's avatar
lijian6 committed
863
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
864
865
                }

lijian6's avatar
lijian6 committed
866
867
868
869
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
Chenggang Zhao's avatar
Chenggang Zhao committed
870
871
872
                    trap();
                }
            }
lijian6's avatar
lijian6 committed
873

lijian6's avatar
lijian6 committed
874
875
            auto src_rdma_head = shfl_sync(cached_rdma_channel_head, src_rdma_rank);
            auto src_rdma_tail = shfl_sync(cached_rdma_channel_tail, src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
876

lijian6's avatar
lijian6 committed
877
878
879
880
881
882
883
884
885
886
            // 遍历RDMA缓冲区中的每一个令牌
            for(int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) {
                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
                // 首先读取SourceMeta,对应到kRDMASenderCoordinator中 kRDMASender 的数据远程写入
                void* shifted           = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
                auto src_meta           = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
                if(lane_id == src_rdma_rank) {
                    num_tokens_to_recv_from_rdma -= 1;
                }

Chenggang Zhao's avatar
Chenggang Zhao committed
887
                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
lijian6's avatar
lijian6 committed
888
                if(lane_id == src_rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
889
890
                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
lijian6's avatar
lijian6 committed
891
                    if(!kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
892
893
                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
                }
lijian6's avatar
lijian6 committed
894
895

                if(!is_in_dst_nvl_rank)
Chenggang Zhao's avatar
Chenggang Zhao committed
896
897
                    continue;

lijian6's avatar
lijian6 committed
898
                // 获取一个空闲槽位
lijian6's avatar
lijian6 committed
899
                int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
900

lijian6's avatar
lijian6 committed
901
                // 复制数据
lijian6's avatar
lijian6 committed
902
903
                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
lijian6's avatar
lijian6 committed
904
905
906
                                   reinterpret_cast<int4*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
lijian6's avatar
lijian6 committed
907

lijian6's avatar
lijian6 committed
908
909
                // 复制源元数据
                if(lane_id == 0)
lijian6's avatar
lijian6 committed
910
                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
lijian6's avatar
lijian6 committed
911
                shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
lijian6's avatar
lijian6 committed
912

lijian6's avatar
lijian6 committed
913
                // 复制 `x_scales`
lijian6's avatar
lijian6 committed
914
915
                UNROLLED_WARP_COPY(1, lane_id, num_scales,
                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
lijian6's avatar
lijian6 committed
916
917
918
919
920
921
922
923
924
925
926
927
928
929
                                   reinterpret_cast<float*>(shifted),
                                   ld_nc_global, st_na_global);
                shifted = reinterpret_cast<float*>(shifted) + num_scales;

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
                    // 读取
                    auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
                    shifted = reinterpret_cast<int*>(shifted) + num_topk;
                    auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);

                    // 转换和写入
                    idx_value = (idx_value >= dst_rank_expert_begin && idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
lijian6's avatar
lijian6 committed
930
                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
lijian6's avatar
lijian6 committed
931
                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
Chenggang Zhao's avatar
Chenggang Zhao committed
932
933
                }

lijian6's avatar
lijian6 committed
934
935
                // 在NVL缓冲区不足的情况下,提前停止
                if((++num_tokens_sent) == num_max_nvl_chunked_send_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
936
937
938
                    src_rdma_tail = i + 1;
            }

lijian6's avatar
lijian6 committed
939
940
941
            // 同步头部索引
            if(lane_id == src_rdma_rank)
                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
Chenggang Zhao's avatar
Chenggang Zhao committed
942

lijian6's avatar
lijian6 committed
943
            // 移动尾部索引,与kNVLReceivers互相通信使用
lijian6's avatar
lijian6 committed
944
            syncwarp();
lijian6's avatar
lijian6 committed
945
946
947
            if(lane_id == 0) {
                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
948
949
950
        }

        // Retired
lijian6's avatar
lijian6 committed
951
        syncwarp();
lijian6's avatar
lijian6 committed
952
        if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
953
            forward_channel_retired[dst_nvl_rank] = true;
lijian6's avatar
lijian6 committed
954
955
956
957
958
959
960
961
962
        }
    } else if(warp_role == WarpRole::kForwarderCoordinator) {
        /*
        这段代码的主要功能是在一个CUDA内核中协调转发器的逻辑。
        它首先检查当前warp是否是额外的转发器协调warp,如果是,则直接退出。
        然后,它清理共享内存,并初始化转发通道的头部和退役状态。
        接着,它进入一个无限循环,在循环中,它找到最小的头部,如果所有的通道都已退役,则退出循环。
        否则,它更新远程头部,并进行纳秒级睡眠,以让其他warp工作。
        */
Chenggang Zhao's avatar
Chenggang Zhao committed
963
        // Extra warps for forwarder coordinator should exit directly
lijian6's avatar
lijian6 committed
964
        if (warp_id > NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
965
966
            return;

lijian6's avatar
lijian6 committed
967
968
969
970
971
972
        // 转发warp协调器
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        // 清理共享内存
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "无效的NVL对等体数量");
#pragma unroll
        for(int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += kWarpSize)
Chenggang Zhao's avatar
Chenggang Zhao committed
973
            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
lijian6's avatar
lijian6 committed
974
        if(lane_id < NUM_MAX_NVL_PEERS)
Chenggang Zhao's avatar
Chenggang Zhao committed
975
            forward_channel_retired[lane_id] = false;
lijian6's avatar
lijian6 committed
976
977
        // sync_forwarder_smem();
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
978
979

        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
lijian6's avatar
lijian6 committed
980
981
982

        while(true) {
            // 找到最小的头部
Chenggang Zhao's avatar
Chenggang Zhao committed
983
            int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
984
#pragma unroll
lijian6's avatar
lijian6 committed
985
986
            for(int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
                if(!forward_channel_retired[i])
lijian6's avatar
lijian6 committed
987
                    min_head = min(min_head, forward_channel_head[i][target_rdma]);
lijian6's avatar
lijian6 committed
988
989

            if(__all_sync(kFullWarpMask, min_head == std::numeric_limits<int>::max())) {
Chenggang Zhao's avatar
Chenggang Zhao committed
990
                break;
lijian6's avatar
lijian6 committed
991
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
992

lijian6's avatar
lijian6 committed
993
994
            // 更新远程头部
            if(min_head != std::numeric_limits<int>::max() && min_head >= last_head + num_max_rdma_chunked_send_tokens && lane_id < kNumRDMARanks){
lijian6's avatar
lijian6 committed
995
996
997
                rocshmem::rocshmem_ctx_ulong_atomic_add(
                    ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_head,
                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
998
999
                last_head = min_head;
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1000

lijian6's avatar
lijian6 committed
1001
            // 纳秒级睡眠并让其他warp工作 // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1002
            __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1003
        }
lijian6's avatar
lijian6 committed
1004
1005
1006
1007
1008
1009
1010
1011
    } else if(warp_role == WarpRole::kNVLReceivers) {
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }

        // Place the main logic of your kernel here, using the parameters above.
        // NVL消费者
        // 从屏障结果中检索秩偏移(每个通道的寄存器存储一个RDMA秩)
Chenggang Zhao's avatar
Chenggang Zhao committed
1012
        int src_nvl_rank = target_rank, total_offset = 0;
lijian6's avatar
lijian6 committed
1013
1014
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "无效的RDMA对等体数量");
        if(lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1015
1016
            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];

lijian6's avatar
lijian6 committed
1017
1018
        // 接收通道偏移
        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
1019
        auto start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1020
1021

        while(lane_id < kNumRDMARanks) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1022
            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1023
            end_offset   = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
lijian6's avatar
lijian6 committed
1024
            if(start_offset < 0 && end_offset < 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1025
1026
1027
1028
                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
                total_offset += start_offset;
                break;
            }
lijian6's avatar
lijian6 committed
1029
1030
1031
1032
            // 超时检查
            if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
1033
1034
1035
                trap();
            }
        }
lijian6's avatar
lijian6 committed
1036

Chenggang Zhao's avatar
Chenggang Zhao committed
1037
1038
        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);

lijian6's avatar
lijian6 committed
1039
1040
1041
        // 保存以供合并使用
        if(lane_id < kNumRDMARanks && !kCachedMode)
            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
lijian6's avatar
lijian6 committed
1042
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1043
1044

        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1045
1046
        while(num_tokens_to_recv > 0) {
            // 通过通道0检查通道状态
lijian6's avatar
lijian6 committed
1047
            start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1048
1049
1050
            while(lane_id == 0) {
                // 准备复制
                if(cached_channel_head_idx != cached_channel_tail_idx)
Chenggang Zhao's avatar
Chenggang Zhao committed
1051
                    break;
lijian6's avatar
lijian6 committed
1052
1053
1054
1055
1056
                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
                // 超时检查
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
                            channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1057
1058
1059
1060
                    trap();
                }
            }

lijian6's avatar
lijian6 committed
1061
            // 同步队列尾部
lijian6's avatar
lijian6 committed
1062
1063
            cached_channel_tail_idx = shfl_sync(cached_channel_tail_idx, 0);

lijian6's avatar
lijian6 committed
1064
            // 复制数据
Chenggang Zhao's avatar
Chenggang Zhao committed
1065
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1066
1067
1068
1069
            for(int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) {
                int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens;
                auto meta               = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
                int64_t recv_token_idx  = shfl_sync(total_offset, meta.src_rdma_rank);
Chenggang Zhao's avatar
Chenggang Zhao committed
1070
1071
                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;

lijian6's avatar
lijian6 committed
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
                // 复制数据
                UNROLLED_WARP_COPY(5,
                                lane_id,
                                hidden_int4,
                                recv_x + recv_token_idx * hidden_int4,
                                nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
                                ld_nc_global,
                                st_na_global);

                // 复制源元数据
                if(lane_id == 0 && !kCachedMode)
Chenggang Zhao's avatar
Chenggang Zhao committed
1083
                    st_na_global(recv_src_meta + recv_token_idx, meta);
lijian6's avatar
lijian6 committed
1084

lijian6's avatar
lijian6 committed
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
                // 复制比例
                UNROLLED_WARP_COPY(1,
                                lane_id,
                                num_scales,
                                recv_x_scales + recv_token_idx * num_scales,
                                nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
                                ld_nc_global,
                                st_na_global);

                // 复制 `topk_idx` 和 `topk_weights`
                if(lane_id < num_topk) {
lijian6's avatar
lijian6 committed
1096
1097
                    auto recv_idx   = recv_token_idx * num_topk + lane_id;
                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
lijian6's avatar
lijian6 committed
1098
1099
                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1100
1101
1102
                }
            }

lijian6's avatar
lijian6 committed
1103
            // 移动队列
lijian6's avatar
lijian6 committed
1104
            syncwarp();
lijian6's avatar
lijian6 committed
1105
            if(lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1106
                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
lijian6's avatar
lijian6 committed
1107
1108
            }
        } // while(num_tokens_to_recv > 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
1109
    }
lijian6's avatar
lijian6 committed
1110
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1111
1112
}

lijian6's avatar
lijian6 committed
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
void dispatch(void *recv_x, float *recv_x_scales, int64_t *recv_topk_idx, float *recv_topk_weights,
              void *recv_src_meta, const void *x, const float *x_scales, const int64_t *topk_idx,
              const float *topk_weights, int *send_rdma_head, int *send_nvl_head,
              int *recv_rdma_channel_prefix_matrix, int *recv_gbl_channel_prefix_matrix,
              const int *rdma_channel_prefix_matrix, const int *recv_rdma_rank_prefix_sum,
              const int *gbl_channel_prefix_matrix, const int *recv_gbl_rank_prefix_sum,
              const bool *is_token_in_rank, int num_tokens, int hidden_int4, int num_scales,
              int num_topk, int num_experts, int scale_token_stride, int scale_hidden_stride,
              void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
              int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
              int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
              int num_ranks, bool is_cached_dispatch, hipStream_t stream, int num_channels,
              bool low_latency_mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1126
    constexpr int kNumDispatchRDMASenderWarps = 7;
lijian6's avatar
lijian6 committed
1127
1128
    // Make sure never OOB
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride < std::numeric_limits<int>::max());
lijian6's avatar
lijian6 committed
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153

#define DISPATCH_LAUNCH_CASE(num_rdma_ranks)                                                       \
    {                                                                                              \
        auto dispatch_func =                                                                       \
            low_latency_mode                                                                       \
                ? (is_cached_dispatch                                                              \
                       ? dispatch<true, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>         \
                       : dispatch<true, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>)       \
                : (is_cached_dispatch                                                              \
                       ? dispatch<false, num_rdma_ranks, true, kNumDispatchRDMASenderWarps>        \
                       : dispatch<false, num_rdma_ranks, false, kNumDispatchRDMASenderWarps>);     \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch_func, reinterpret_cast<int4 *>(recv_x), recv_x_scales, recv_topk_idx,   \
            recv_topk_weights, reinterpret_cast<SourceMeta *>(recv_src_meta),                      \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, send_rdma_head,   \
            send_nvl_head, recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix,        \
            rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix,      \
            recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, num_scales,       \
            num_topk, num_experts, scale_token_stride, scale_hidden_stride, rdma_buffer_ptr,       \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

    EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
Chenggang Zhao's avatar
Chenggang Zhao committed
1154
1155
    EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

lijian6's avatar
lijian6 committed
1156
1157
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
                        (1 + NUM_MAX_NVL_PEERS) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1158
1159
1160
1161
    SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lijian6's avatar
lijian6 committed
1162
template <bool kLowLatencyMode>
lijian6's avatar
lijian6 committed
1163
__global__ void __launch_bounds__(1024, 1)
lijian6's avatar
lijian6 committed
1164
1165
1166
1167
1168
1169
1170
1171
cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, const int nvl_clean_offset,
              const int nvl_num_int_clean, int *combined_rdma_head, int num_combined_tokens,
              int num_channels, const int *rdma_channel_prefix_matrix,
              const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
              void **buffer_ptrs, int **barrier_signal_ptrs, int rank, int num_ranks,
              bool is_cached_dispatch, const rocshmem::rocshmem_team_t rdma_team) {
    auto sm_id       = static_cast<int>(blockIdx.x);
    auto thread_id   = static_cast<int>(threadIdx.x);
Chenggang Zhao's avatar
Chenggang Zhao committed
1172
    auto num_threads = static_cast<int>(blockDim.x);
lijian6's avatar
lijian6 committed
1173
1174
1175
    auto num_warps   = num_threads / kWarpSize;
    auto warp_id     = thread_id / kWarpSize;
    auto lane_id     = get_lane_id();
Chenggang Zhao's avatar
Chenggang Zhao committed
1176

lijian6's avatar
lijian6 committed
1177
    auto nvl_rank       = rank % NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1178
1179
1180
1181
1182
    auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Using two SMs, which clean the RDMA/NVL buffer respectively
    if (sm_id == 0) {
        // Barrier for RDMA
lijian6's avatar
lijian6 committed
1183
1184
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Shangyan Zhou's avatar
Fix  
Shangyan Zhou committed
1185

lijian6's avatar
lijian6 committed
1186
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1187

lijian6's avatar
lijian6 committed
1188
1189
        // Clean
        auto rdma_buffer_ptr_int = reinterpret_cast<int *>(rdma_buffer_ptr);
Chenggang Zhao's avatar
Chenggang Zhao committed
1190
1191
        for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
            rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1192
1193
1194
1195
1196
1197
        rocshmem::rocshmem_fence();
        __syncthreads();

        // Barrier again
        if (thread_id == 0)
            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1198

lijian6's avatar
lijian6 committed
1199
1200
1201
1202
1203
1204
1205
    } else if (sm_id == 1) {
        // Barrier for NVL
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
        __syncthreads();

        // Clean
        auto nvl_buffer_ptr_int = reinterpret_cast<int *>(buffer_ptrs[nvl_rank]);
Chenggang Zhao's avatar
Chenggang Zhao committed
1206
1207
        for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
            nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
lijian6's avatar
lijian6 committed
1208
        memory_fence();
1209
        __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1210
1211

        // Barrier again
1212
        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
lijian6's avatar
lijian6 committed
1213
    } else if (sm_id == 2) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1214
1215
1216
1217
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1218
        EP_DEVICE_ASSERT(num_rdma_ranks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
1219
1220
1221
1222

        // Iterate in reverse order
        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1223
1224
            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx,
                                   token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1225
1226
1227

            // NOTES: `1 << 25` is a heuristic large number
            int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1228
1229
1230
            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                auto current_head =
                    __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
                if (current_head < 0) {
                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
                } else {
                    last_head = current_head;
                }
            }
        }
    } else {
        if (is_cached_dispatch)
            return;

        EP_DEVICE_ASSERT(num_warps >= num_channels);
lijian6's avatar
lijian6 committed
1243
1244
1245
        EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and
                                  rdma_rank_prefix_sum != nullptr);
        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Too many NVL peers");
1246

lijian6's avatar
lijian6 committed
1247
1248
1249
        if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) {
            for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks;
                 dst_rdma_rank += num_channels * 2 - 3) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1250
                // Iterate in reverse order
lijian6's avatar
lijian6 committed
1251
1252
1253
1254
1255
1256
                int token_start_idx =
                    warp_id == 0
                        ? 0
                        : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
                int token_end_idx =
                    rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
Chenggang Zhao's avatar
Chenggang Zhao committed
1257
1258
1259
1260
1261
                int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                token_start_idx += shift, token_end_idx += shift;

                // NOTES: `1 << 25` is a heuristic large number
                int last_head = 1 << 25;
lijian6's avatar
lijian6 committed
1262
1263
1264
1265
1266
1267
1268
                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; --token_idx) {
                    auto current_head =
                        __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
                    if (current_head < 0) {
                        combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1;
                    } else {
                        last_head = current_head;
1269
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1270
1271
1272
1273
1274
1275
1276
                }
            }
        }
    }
}

void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights,
lijian6's avatar
lijian6 committed
1277
1278
1279
1280
1281
1282
                   int num_ranks, int num_channels, int num_combined_tokens,
                   int *combined_rdma_head, const int *rdma_channel_prefix_matrix,
                   const int *rdma_rank_prefix_sum, int *combined_nvl_head, void *rdma_buffer_ptr,
                   int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
                   int num_max_nvl_chunked_recv_tokens, int **barrier_signal_ptrs, int rank,
                   hipStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
Chenggang Zhao's avatar
Chenggang Zhao committed
1283
                   bool is_cached_dispatch, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1284
    const int  num_threads    = ::max(128, kWarpSize * num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1285
1286
1287
    const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // Get clean meta
lijian6's avatar
lijian6 committed
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
    auto rdma_clean_meta =
        get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                            num_max_rdma_chunked_recv_tokens, num_channels);
    auto nvl_clean_meta =
        get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks,
                           NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels);
    EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <=
                       num_rdma_bytes);
    EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <=
                       num_nvl_bytes);
Chenggang Zhao's avatar
Chenggang Zhao committed
1298
1299
1300
1301
1302
    EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits<int>::max());
    EP_HOST_ASSERT(num_channels * 2 > 3);

    // Launch kernel
lijian6's avatar
lijian6 committed
1303
    auto cached_notify_func = low_latency_mode ? cached_notify<true> : cached_notify<false>;
Chenggang Zhao's avatar
Chenggang Zhao committed
1304
    SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream);
lijian6's avatar
lijian6 committed
1305
1306
1307
1308
1309
1310
    LAUNCH_KERNEL_NON_COOPERATIVE(
        &cfg, cached_notify_func, rdma_clean_meta.first, rdma_clean_meta.second,
        nvl_clean_meta.first, nvl_clean_meta.second, combined_rdma_head, num_combined_tokens,
        num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
        rdma_buffer_ptr, buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch,
        cpu_rdma_team);
Chenggang Zhao's avatar
Chenggang Zhao committed
1311
1312
}

lijian6's avatar
lijian6 committed
1313
1314
1315
1316
1317
// Warp-cooperative reduction of one token over every rank that holds a copy of it.
//
// Template parameters:
//   kNumRanks    - number of candidate ranks; lane `i` (i < kNumRanks) describes rank `i`.
//   dtype_t      - element type packed inside each `int4` of the hidden vector.
//   kMaxNumRanks - capacity of the per-token rank list (must be <= kWarpSize). It may be smaller
//                  than kNumRanks, in which case at most kMaxNumRanks lanes may report the token.
//   ReceiveFn    - callable `int4(int rank, int slot_idx, int hidden_int4_idx)` that loads one
//                  `int4` of the received copy.
//   ReceiveTWFn  - callable `float(int rank, int slot_idx, int topk_idx)` that loads one
//                  received top-k weight.
//
// Per-lane arguments:
//   is_token_in_rank, head_idx - lane `i` says whether rank `i` holds the token and gives its
//                                queue head; the slot read is `head_idx % num_max_recv_tokens`.
//   combined_row               - output row of `hidden_int4` int4s (sum of all copies,
//                                accumulated in float and cast back to dtype_t).
//   combined_topk_weights      - output of `num_topk` floats (sum of all copies' weights).
//
// Returns the smallest rank index holding the token, or -1 when no rank holds it.
template <int kNumRanks, typename dtype_t, int kMaxNumRanks, typename ReceiveFn, typename ReceiveTWFn>
__device__ int combine_token(bool is_token_in_rank, int head_idx,
                             int lane_id, int hidden_int4, int num_topk,
                             int4* combined_row, float* combined_topk_weights,
                             int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) {
    constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);

    // Broadcast current heads
    // Lane `i` holds the head of rank `i` and `is_token_in_rank`
    EP_STATIC_ASSERT(kMaxNumRanks <= kWarpSize, "Too many ranks");
    int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks];
    #pragma unroll
    for (int i = 0; i < kNumRanks; ++ i) if (shfl_sync(is_token_in_rank, i)) {
        // Check capacity *before* writing so an unexpected extra rank traps instead of
        // overflowing the local arrays. When kNumRanks <= kMaxNumRanks overflow is impossible
        // and this compile-time-constant branch is eliminated.
        if (kNumRanks > kMaxNumRanks) {
            EP_DEVICE_ASSERT(num_topk_ranks < kMaxNumRanks);
        }
        slot_indices[num_topk_ranks] = shfl_sync(head_idx, i) % num_max_recv_tokens;
        topk_ranks[num_topk_ranks ++] = i;
    }

    // Reduce data
    #pragma unroll
    for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
        // Read buffers
        // TODO: maybe too many registers here
        int4 recv_value_int4[kMaxNumRanks];
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j)
            recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i);

        // Reduce all-to-all results (accumulate in float)
        float values[kDtypePerInt4] = {0};
        #pragma unroll
        for (int j = 0; j < num_topk_ranks; ++ j) {
            auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
            #pragma unroll
            for (int k = 0; k < kDtypePerInt4; ++ k)
                values[k] += static_cast<float>(recv_value_dtypes[k]);
        }

        // Cast back to `dtype_t` and write
        int4 out_int4;
        auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
        #pragma unroll
        for (int j = 0; j < kDtypePerInt4; ++ j)
            out_dtypes[j] = static_cast<dtype_t>(values[j]);
        st_na_global(combined_row + i, out_int4);
    }

    // Reduce `topk_weights`
    if (lane_id < num_topk) {
        float value = 0;
        #pragma unroll
        for (int i = 0; i < num_topk_ranks; ++ i)
            value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id);
        st_na_global(combined_topk_weights + lane_id, value);
    }

    // Return the minimum top-k rank (ranks were collected in ascending order).
    // Return -1 if no rank holds the token, so an uninitialized element is never read.
    return num_topk_ranks > 0 ? topk_ranks[0] : -1;
}

lijian6's avatar
lijian6 committed
1373
1374
1375
1376
template <bool kLowLatencyMode,
          int kNumRDMARanks,
          typename dtype_t,
          int kNumCombineForwarderWarps,
lijian6's avatar
lijian6 committed
1377
          int kNumTopkRDMARanks     = get_num_topk_rdma_ranks(kNumRDMARanks),
lijian6's avatar
lijian6 committed
1378
          int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1,
lijian6's avatar
lijian6 committed
1379
          int kNumForwarders        = kNumRDMARanks * kNumWarpsPerForwarder,
lijian6's avatar
lijian6 committed
1380
1381
1382
          int kNumRDMAReceivers     = kNumForwarders>
__global__ void __launch_bounds__((1 + NUM_MAX_NVL_PEERS) * kWarpSize, 1) 
combine(int4 *combined_x, float *combined_topk_weights, const bool *is_combined_token_in_rank,
lijian6's avatar
lijian6 committed
1383
1384
1385
1386
1387
1388
1389
1390
            const int4 *x, const float *topk_weights, const int4 *bias_0, const int4 *bias_1,
            const int *combined_rdma_head, const int *combined_nvl_head, const SourceMeta *src_meta,
            const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
            const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
            int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
            int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
            int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
            int num_ranks) {
lijian6's avatar
lijian6 committed
1391
1392
1393
1394
1395
1396
1397
    enum class WarpRole {
        kNVLSender,
        kNVLAndRDMAForwarder,
        kRDMAReceiver,
        kRDMACoordinator,
        kNVLCoordinator
    };
Chenggang Zhao's avatar
Chenggang Zhao committed
1398

lijian6's avatar
lijian6 committed
1399
1400
1401
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);

lijian6's avatar
lijian6 committed
1402
1403
1404
1405
1406
1407
    const auto sm_id       = static_cast<int>(blockIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / kWarpSize;
    const auto thread_id   = static_cast<int>(threadIdx.x), warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_channels = static_cast<int>(gridDim.x) / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL,
               channel_id   = sm_id / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;

Chenggang Zhao's avatar
Chenggang Zhao committed
1408
1409
1410
1411
    const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t));

    // NOTES: we decouple a channel into 2 SMs
    const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
lijian6's avatar
lijian6 committed
1412
1413
1414
1415
1416
1417
1418

    const auto role_meta = [=]() -> std::pair<WarpRole, int> {
        if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 1) {
            return {WarpRole::kNVLSender, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
        } else if (sm_id % NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL == 0) {
            if(warp_id < kNumForwarders) {
                return {WarpRole::kNVLAndRDMAForwarder, (warp_id + channel_id) % kNumForwarders};
Chenggang Zhao's avatar
Chenggang Zhao committed
1419
            } else {
lijian6's avatar
lijian6 committed
1420
                return {WarpRole::kRDMACoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1421
1422
            }
        } else {
lijian6's avatar
lijian6 committed
1423
            if(warp_id < kNumForwarders) {
lijian6's avatar
lijian6 committed
1424
                return {WarpRole::kRDMAReceiver, warp_id};
Chenggang Zhao's avatar
Chenggang Zhao committed
1425
            } else {
lijian6's avatar
lijian6 committed
1426
                return {WarpRole::kNVLCoordinator, 0};
Chenggang Zhao's avatar
Chenggang Zhao committed
1427
1428
1429
            }
        }
    }();
lijian6's avatar
lijian6 committed
1430

Chenggang Zhao's avatar
Chenggang Zhao committed
1431
    auto warp_role = role_meta.first;
lijian6's avatar
lijian6 committed
1432
    auto target_rank = role_meta.second; // Not applicable for RDMA senders
Chenggang Zhao's avatar
Chenggang Zhao committed
1433

lijian6's avatar
lijian6 committed
1434
    EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
1435
1436
    auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1437
    // This approach is designed to sync multiple warps in a loop
lijian6's avatar
lijian6 committed
1438
1439
1440
1441
    constexpr int num_sync_large_iteration = 64;
    constexpr int rdma_warp_counters = kNumRDMARanks * num_sync_large_iteration;
    __shared__ volatile int sync_large_warp_counters[2 * rdma_warp_counters];   
    for (int i = thread_id; i < 2 * rdma_warp_counters; i += num_threads) {
lijian6's avatar
lijian6 committed
1442
1443
1444
1445
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

Chenggang Zhao's avatar
Chenggang Zhao committed
1446
    if (warp_role == WarpRole::kNVLSender) {
lijian6's avatar
lijian6 committed
1447
1448
1449
        if(warp_id >= NUM_MAX_NVL_PEERS) {
            return;
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
1450

lijian6's avatar
lijian6 committed
1451
        const auto dst_nvl_rank = target_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
1452
1453
1454
        // NVL layouts
        // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources
        auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank];
lijian6's avatar
lijian6 committed
1455
1456
1457
1458
1459
        auto nvl_channel_x = AsymBuffer<int4>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_topk_weights = AsymBuffer<float>(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
        auto nvl_channel_head = AsymBuffer<int>(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr);
        auto nvl_channel_tail = AsymBuffer<int>(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr);
1460

Chenggang Zhao's avatar
Chenggang Zhao committed
1461
1462
        // Get tasks for each RDMA lane
        int token_start_idx = 0, token_end_idx = 0;
lijian6's avatar
lijian6 committed
1463
1464
        if(lane_id < kNumRDMARanks) {
            int prefix_idx  = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id;
Chenggang Zhao's avatar
Chenggang Zhao committed
1465
            token_start_idx = gbl_channel_prefix_matrix[prefix_idx];
lijian6's avatar
lijian6 committed
1466
            token_end_idx   = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1467
        }
lijian6's avatar
lijian6 committed
1468
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1469
1470
1471

        // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1472
        EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1473
1474

        // Iterate over all tokens and send by chunks
lijian6's avatar
lijian6 committed
1475
        while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1476
            // Exit if possible
lijian6's avatar
lijian6 committed
1477
            if(__all_sync(kFullWarpMask, token_start_idx >= token_end_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1478
1479
                break;

lijian6's avatar
lijian6 committed
1480
            // Decide next RDMA buffer to send
Chenggang Zhao's avatar
Chenggang Zhao committed
1481
            bool is_lane_ready = false;
lijian6's avatar
lijian6 committed
1482
            auto start_time    = wall_clock64();
lijian6's avatar
lijian6 committed
1483
1484

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1485
                int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
1486
                is_lane_ready      = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and
lijian6's avatar
lijian6 committed
1487
1488
1489
                                num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens;

                if(__any_sync(kFullWarpMask, is_lane_ready))
Chenggang Zhao's avatar
Chenggang Zhao committed
1490
                    break;
lijian6's avatar
lijian6 committed
1491

Chenggang Zhao's avatar
Chenggang Zhao committed
1492
                // Retry
lijian6's avatar
lijian6 committed
1493
1494
                if(lane_id < kNumRDMARanks and token_start_idx < token_end_idx)
                    cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1495
1496

                // Timeout check
lijian6's avatar
lijian6 committed
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
                if(wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
                    printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, "
                        "RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n",
                        channel_id,
                        rdma_rank,
                        nvl_rank,
                        dst_nvl_rank,
                        lane_id,
                        ld_volatile_global(nvl_channel_head.buffer() + lane_id),
                        cached_channel_tail_idx,
                        token_start_idx,
                        token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1509
1510
1511
1512
1513
                    trap();
                }
            }

            // Sync token start index and count
lijian6's avatar
lijian6 committed
1514
1515
            for(int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++current_rdma_idx) {
                if(shfl_sync((token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx))
Chenggang Zhao's avatar
Chenggang Zhao committed
1516
1517
1518
                    continue;

                // Sync token start index
lijian6's avatar
lijian6 committed
1519
1520
                auto token_idx          = static_cast<int64_t>(shfl_sync(token_start_idx, current_rdma_idx));
                int num_tokens_in_chunk = shfl_sync(min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1521
1522

                // Send by chunk
lijian6's avatar
lijian6 committed
1523
                for(int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++chunk_idx, ++token_idx) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1524
1525
                    // Get an empty slot
                    int dst_slot_idx = 0;
lijian6's avatar
lijian6 committed
1526
1527
1528
                    if(lane_id == current_rdma_idx) {
                        dst_slot_idx = (cached_channel_tail_idx++) % num_max_nvl_chunked_recv_tokens_per_rdma;
                        dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
1529
                    }
lijian6's avatar
lijian6 committed
1530
                    dst_slot_idx = shfl_sync(dst_slot_idx, current_rdma_idx);
lijian6's avatar
lijian6 committed
1531
1532
1533
1534

                    // Copy data
                    auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x         = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
1535
                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
1536

lijian6's avatar
lijian6 committed
1537
                    // Copy source meta
lijian6's avatar
lijian6 committed
1538
1539
                    if(lane_id == 0)
                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx));
Chenggang Zhao's avatar
Chenggang Zhao committed
1540

lijian6's avatar
lijian6 committed
1541
                    // Copy `topk_weights`
lijian6's avatar
lijian6 committed
1542
1543
1544
                    if(lane_id < num_topk)
                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id,
                                    ld_nc_global(topk_weights + token_idx * num_topk + lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1545
1546
1547
1548
1549
                }
                lane_id == current_rdma_idx ? (token_start_idx = static_cast<int>(token_idx)) : 0;
            }

            // Move queue tail
lijian6's avatar
lijian6 committed
1550
            syncwarp();
lijian6's avatar
lijian6 committed
1551
1552
1553
            if(lane_id < kNumRDMARanks and is_lane_ready) {
                st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx);
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
1554
1555
        }
    } else {
lijian6's avatar
lijian6 committed
1556
1557
1558
1559
        if(warp_id > kNumForwarders) {
            return;
        }

Chenggang Zhao's avatar
Chenggang Zhao committed
1560
1561
        // Combiners and coordinators
        // RDMA symmetric layout
lijian6's avatar
lijian6 committed
1562
        auto hidden_bytes = hidden_int4 * sizeof(int4);
lijian6's avatar
lijian6 committed
1563
        auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk);
lijian6's avatar
lijian6 committed
1564
1565
1566
        auto rdma_channel_data = SymBuffer<int8_t>(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
        auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
Chenggang Zhao's avatar
Chenggang Zhao committed
1567
1568

        // NVL layouts
lijian6's avatar
lijian6 committed
1569
1570
1571
1572
        void* local_nvl_buffer = buffer_ptrs[nvl_rank];
        void* nvl_buffers[NUM_MAX_NVL_PEERS];
        #pragma unroll
        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i)
Chenggang Zhao's avatar
Chenggang Zhao committed
1573
            nvl_buffers[i] = buffer_ptrs[i];
lijian6's avatar
lijian6 committed
1574
1575
1576
1577
1578
        auto nvl_channel_x = AsymBuffer<int4>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_topk_weights = AsymBuffer<float>(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
        auto nvl_channel_head = AsymBuffer<int, NUM_MAX_NVL_PEERS>(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer);
        auto nvl_channel_tail = AsymBuffer<int>(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also<NUM_MAX_NVL_PEERS>(nvl_buffers);
Chenggang Zhao's avatar
Chenggang Zhao committed
1579
1580

        // Combiner warp synchronization
lijian6's avatar
lijian6 committed
1581
        __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS];
Chenggang Zhao's avatar
Chenggang Zhao committed
1582
        __shared__ volatile bool forwarder_retired[kNumForwarders];
lijian6's avatar
lijian6 committed
1583
        __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks];
Chenggang Zhao's avatar
Chenggang Zhao committed
1584
        __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers];
lijian6's avatar
lijian6 committed
1585

Chenggang Zhao's avatar
Chenggang Zhao committed
1586
1587
1588
        if (warp_role == WarpRole::kNVLAndRDMAForwarder) {
            // Receive from NVL ranks and forward to RDMA ranks
            // NOTES: this part is using "large warps" for each RDMA ranks
lijian6's avatar
lijian6 committed
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
            const auto dst_rdma_rank = target_rank / kNumWarpsPerForwarder;
            const auto sub_warp_id   = target_rank % kNumWarpsPerForwarder;
            auto send_buffer         = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank);
            // auto sync_large_warp     = [=]() {
            //     if(kNumWarpsPerForwarder == 1) {
            //         syncwarp();
            //     } else {
            //         // asm volatile("bar.sync %0, %1;" ::"r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * kWarpSize));
            //         // __syncthreads();
            //         syncwarp();
            //     }
            // };
            auto sync_large_warp = [=](const int iter, const int mode) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1602
                if (kNumWarpsPerForwarder == 1) {
lijian6's avatar
lijian6 committed
1603
                    syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1604
                } else {
lijian6's avatar
lijian6 committed
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
                        // LDS index to store for sync
                        int lds_dst_rdma_rank = dst_rdma_rank + (iter % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        //reset index in the LDS to avoid race condition due to warp scheduling
                        int reset_idx =         dst_rdma_rank + ((iter + num_sync_large_iteration/2) % num_sync_large_iteration) * kNumRDMARanks + mode * rdma_warp_counters;
                        auto start_time = wall_clock64();
                        if (lane_id == 0){
                            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[lds_dst_rdma_rank], 1);
                        }
                        syncwarp();
                        //The while(...) loop polls the counter until all warps have arrived
                        if (lane_id == 0){
                            while (sync_large_warp_counters[lds_dst_rdma_rank] < (kNumWarpsPerForwarder)){
                                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                                    printf("DeepEP combine sync timeout. current num_sync_large_iteration %d. double it.\n", num_sync_large_iteration );
                                    trap();
                                }
lijian6's avatar
lijian6 committed
1621
1622
                            }
                        }
lijian6's avatar
lijian6 committed
1623
1624
1625
1626
1627
                        syncwarp();
                        if (lane_id == 0 && sync_large_warp_counters[reset_idx] == kNumWarpsPerForwarder){
                            sync_large_warp_counters[reset_idx] = 0;
                        }
                        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1628
1629
                }
            };
lijian6's avatar
lijian6 committed
1630
            EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= kNumCombineForwarderWarps, "Barriers are not enough");
1631

lijian6's avatar
lijian6 committed
1632
1633
            // Advance to the corresponding NVL buffer, 基于原本指针进行的地址偏移
            nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4);
lijian6's avatar
lijian6 committed
1634
            nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma);
lijian6's avatar
lijian6 committed
1635
            nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk);
Chenggang Zhao's avatar
Chenggang Zhao committed
1636
1637
1638
1639
            nvl_channel_head.advance(dst_rdma_rank);
            nvl_channel_tail.advance(dst_rdma_rank);

            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1640
1641
1642
1643
1644
            EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= kWarpSize, "Invalid number of NVL peers");
            lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (forwarder_retired[target_rank] = false) : false;
            // sync_forwarder_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1645
1646
1647

            // Get count and cached head
            int cached_nvl_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1648
1649
            int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
            int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
Chenggang Zhao's avatar
Chenggang Zhao committed
1650
1651
1652
1653
1654
            num_tokens_to_combine -= num_tokens_prefix;
            num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
            combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS;

            // Iterate over all tokens and combine by chunks
lijian6's avatar
lijian6 committed
1655
            for(int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1656
                // Check destination queue emptiness, or wait a buffer to be released
lijian6's avatar
lijian6 committed
1657
                auto token_end_idx      = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine);
Chenggang Zhao's avatar
Chenggang Zhao committed
1658
                auto num_chunked_tokens = token_end_idx - token_start_idx;
lijian6's avatar
lijian6 committed
1659
                auto start_time         = wall_clock64();
lijian6's avatar
lijian6 committed
1660
1661
1662
1663
1664
1665
                while(sub_warp_id == 0 and lane_id == 0) {
                    // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens`
                    // Here, `token_start_idx` is the actual tail
                    int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank));

                    if(num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens)
Chenggang Zhao's avatar
Chenggang Zhao committed
1666
1667
1668
                        break;

                    // Timeout check
lijian6's avatar
lijian6 committed
1669
1670
1671
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n",
                                channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens);
Chenggang Zhao's avatar
Chenggang Zhao committed
1672
1673
1674
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1675
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1676
                sync_large_warp(token_start_idx, 0);
lijian6's avatar
lijian6 committed
1677

Chenggang Zhao's avatar
Chenggang Zhao committed
1678
                // Combine and write to the RDMA buffer
lijian6's avatar
lijian6 committed
1679
                for(int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1680
                    // Read expected head
lijian6's avatar
lijian6 committed
1681
                    EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1682
                    int expected_head = -1;
lijian6's avatar
lijian6 committed
1683
1684
                    if(lane_id < NUM_MAX_NVL_PEERS)
                        expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id);
Chenggang Zhao's avatar
Chenggang Zhao committed
1685
1686

                    // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1687
                    start_time = wall_clock64();
lijian6's avatar
lijian6 committed
1688
1689
                    while(cached_nvl_channel_tail_idx <= expected_head) {
                        cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id));
Chenggang Zhao's avatar
Chenggang Zhao committed
1690
1691

                        // Timeout check
lijian6's avatar
lijian6 committed
1692
1693
1694
                        if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) {
                            printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n",
                                    channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1695
1696
1697
1698
1699
                            trap();
                        }
                    }

                    // Combine current token
lijian6's avatar
lijian6 committed
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
                    auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens;
                    void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token;
                    auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); };
                    auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); };
                    combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
                                                                                    expected_head, lane_id,
                                                                                    hidden_int4, num_topk,
                                                                                    reinterpret_cast<int4*>(shifted),
                                                                                    reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
                                                                                    num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1710
1711

                    // Update head
lijian6's avatar
lijian6 committed
1712
1713
1714
1715
                    if(lane_id < NUM_MAX_NVL_PEERS) {
                        expected_head < 0 ? (forwarder_nvl_head[target_rank][lane_id] = -expected_head - 1)
                                        : (forwarder_nvl_head[target_rank][lane_id] = expected_head + 1);
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1716
                }
lijian6's avatar
lijian6 committed
1717
                // sync_large_warp();
lijian6's avatar
lijian6 committed
1718
                sync_large_warp(token_start_idx, 1);
lijian6's avatar
lijian6 committed
1719

Chenggang Zhao's avatar
Chenggang Zhao committed
1720
                // Issue RDMA send
lijian6's avatar
lijian6 committed
1721
1722
                if(sub_warp_id == kNumWarpsPerForwarder - 1) {
                    if(dst_rdma_rank != rdma_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1723
                        auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
lijian6's avatar
lijian6 committed
1724
                        rocshmem::rocshmem_ctx_schar_put_nbi_wave(
lijian6's avatar
lijian6 committed
1725
1726
1727
1728
1729
1730
1731
1732
1733
                            ctx,
                            rdma_channel_data.recv_buffer(rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            rdma_channel_data.send_buffer(dst_rdma_rank) +
                                rdma_slot_idx * num_bytes_per_rdma_token,
                            num_chunked_tokens * num_bytes_per_rdma_token,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

                        rocshmem::rocshmem_ctx_quiet(ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1734
1735
1736
1737
1738
                    } else {
                        memory_fence();
                    }

                    // Write new RDMA tail
lijian6's avatar
lijian6 committed
1739
                    syncwarp();
lijian6's avatar
lijian6 committed
1740
                    if(lane_id == 0) {
lijian6's avatar
lijian6 committed
1741
1742
1743
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
lijian6's avatar
lijian6 committed
1744
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1745
1746
1747
1748
                }
            }

            // Retired
lijian6's avatar
lijian6 committed
1749
            syncwarp();
lijian6's avatar
lijian6 committed
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
            if(lane_id == 0) {
                forwarder_retired[target_rank] = true;
            }
        } else if (warp_role == WarpRole::kRDMACoordinator) {
            // Coordinator
            // Sync shared memory status
            // sync_forwarder_smem();
            __syncthreads();
            constexpr int num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

            int last_nvl_head[kNumRDMARanks] = {0};
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
                // Retired
                if(__all_sync(kFullWarpMask, lane_id >= kNumForwarders or forwarder_retired[lane_id]))
                    break;

                {
                    // Find minimum head for NVL ranks
                    #pragma unroll
                    for(int i = 0; i < kNumRDMARanks; ++i) {
                        int min_head = std::numeric_limits<int>::max();
                        #pragma unroll
                        for(int j = 0; j < num_warps_per_rdma_rank; ++j)
                            if(not forwarder_retired[i * num_warps_per_rdma_rank + j])
                                min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]);

                        if(min_head != std::numeric_limits<int>::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) {
                            st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head);
                        }
                    }
                }

                // Nanosleep and let other warps work
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
            }
        } else if(warp_role == WarpRole::kRDMAReceiver) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1789
1790
            // Receive from RDMA ranks and write to the output tensor
            // Clean shared memory and sync
lijian6's avatar
lijian6 committed
1791
            EP_DEVICE_ASSERT(kNumRDMARanks <= kWarpSize);
lijian6's avatar
lijian6 committed
1792
1793
1794
1795
            lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[target_rank][lane_id] = 0) : 0;
            lane_id == 0 ? (rdma_receiver_retired[target_rank] = false) : 0;
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1796
1797
1798

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
1799
            get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1800
1801
1802

            // Iterate over all tokens and combine
            int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
1803
            for(int64_t token_idx = token_start_idx + target_rank; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1804
                // Read expected head
lijian6's avatar
lijian6 committed
1805
                EP_STATIC_ASSERT(kNumRDMARanks <= kWarpSize, "Invalid number of RDMA peers");
Chenggang Zhao's avatar
Chenggang Zhao committed
1806
                int expected_head = -1;
lijian6's avatar
lijian6 committed
1807
1808
1809
1810
                if(lane_id < kNumRDMARanks) {
                    expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id);
                    (expected_head < 0) ? (rdma_receiver_rdma_head[target_rank][lane_id] = -expected_head - 1)
                                        : (rdma_receiver_rdma_head[target_rank][lane_id] = expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1811
1812
1813
                }

                // Wait lanes to be ready
lijian6's avatar
lijian6 committed
1814
                auto start_time = wall_clock64();
Chenggang Zhao's avatar
Chenggang Zhao committed
1815
                while (cached_channel_tail_idx <= expected_head) {
lijian6's avatar
lijian6 committed
1816
                    cached_channel_tail_idx = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id)));
Chenggang Zhao's avatar
Chenggang Zhao committed
1817
1818

                    // Timeout check
lijian6's avatar
lijian6 committed
1819
1820
1821
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n",
                                channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head);
Chenggang Zhao's avatar
Chenggang Zhao committed
1822
1823
1824
                        trap();
                    }
                }
lijian6's avatar
lijian6 committed
1825
                syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
1826
1827

                // Combine current token
lijian6's avatar
lijian6 committed
1828
1829
1830
1831
1832
1833
1834
1835
                auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast<const int4*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);};
                auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast<const float*>(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);};
                combine_token<kNumRDMARanks, dtype_t, kNumTopkRDMARanks>(expected_head >= 0,
                                                                            expected_head, lane_id,
                                                                            hidden_int4, num_topk,
                                                                            combined_x + token_idx * hidden_int4,
                                                                            combined_topk_weights + token_idx * num_topk,
                                                                            num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn);
Chenggang Zhao's avatar
Chenggang Zhao committed
1836
1837
1838
            }

            // Retired
lijian6's avatar
lijian6 committed
1839
            syncwarp();
lijian6's avatar
lijian6 committed
1840
1841
1842
1843
            if(lane_id == 0) {
                rdma_receiver_retired[target_rank] = true;
            }
        } else if(warp_role == WarpRole::kNVLCoordinator) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1844
1845
            // Coordinator
            // Sync shared memory status
lijian6's avatar
lijian6 committed
1846
1847
            // sync_rdma_receiver_smem();
            __syncthreads();
Chenggang Zhao's avatar
Chenggang Zhao committed
1848
1849
            const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks;

lijian6's avatar
lijian6 committed
1850
            int last_rdma_head               = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
1851
            int last_nvl_head[kNumRDMARanks] = {0};
lijian6's avatar
lijian6 committed
1852
1853
            int dst_rdma_rank                = lane_id < kNumRDMARanks ? lane_id : 0;
            int dst_nvl_rank                 = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
lijian6's avatar
lijian6 committed
1854
1855
1856
            EP_STATIC_ASSERT(kNumCombineForwarderWarps <= kWarpSize, "Invalid number of forwarder warps");

            while(true) {
Chenggang Zhao's avatar
Chenggang Zhao committed
1857
                // Retired
lijian6's avatar
lijian6 committed
1858
                if(__all_sync(kFullWarpMask, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id]))
Chenggang Zhao's avatar
Chenggang Zhao committed
1859
1860
1861
                    break;

                // Find minimum head for RDMA ranks
lijian6's avatar
lijian6 committed
1862
                {
Chenggang Zhao's avatar
Chenggang Zhao committed
1863
                    int min_head = std::numeric_limits<int>::max();
lijian6's avatar
lijian6 committed
1864
1865
1866
    #pragma unroll
                    for(int i = 0; i < kNumRDMAReceivers; ++i)
                        if(not rdma_receiver_retired[i])
lijian6's avatar
lijian6 committed
1867
                            min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
lijian6's avatar
lijian6 committed
1868
1869

                    if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
lijian6's avatar
lijian6 committed
1870
1871
1872
1873
                        rocshmem::rocshmem_ctx_ulong_atomic_add(
                            ctx, rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));

1874
1875
                        last_rdma_head = min_head;
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
1876
1877
1878
                }

                // Nanosleep and let other warps work
lijian6's avatar
lijian6 committed
1879
                __builtin_amdgcn_s_sleep(NUM_WAIT_CYCLES_TIMES_64);
Chenggang Zhao's avatar
Chenggang Zhao committed
1880
1881
1882
            }
        }
    }
lijian6's avatar
lijian6 committed
1883
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
Chenggang Zhao's avatar
Chenggang Zhao committed
1884
1885
}

lijian6's avatar
lijian6 committed
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
void combine(hipDataType type, void *combined_x, float *combined_topk_weights,
             const bool *is_combined_token_in_rank, const void *x, const float *topk_weights,
             const void *bias_0, const void *bias_1, const int *combined_rdma_head,
             const int *combined_nvl_head, const void *src_meta,
             const int *rdma_channel_prefix_matrix, const int *rdma_rank_prefix_sum,
             const int *gbl_channel_prefix_matrix, int num_tokens, int num_combined_tokens,
             int hidden, int num_topk, void *rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
             int num_max_rdma_chunked_recv_tokens, void **buffer_ptrs,
             int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, int rank,
             int num_ranks, hipStream_t stream, int num_channels, bool low_latency_mode) {
lijian6's avatar
lijian6 committed
1896
    constexpr int kNumCombineForwarderWarps = 8;
lijian6's avatar
lijian6 committed
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915

#define COMBINE_LAUNCH_CASE(num_rdma_ranks)                                                        \
    {                                                                                              \
        auto combine_func =                                                                        \
            low_latency_mode                                                                       \
                ? combine<true, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>           \
                : combine<false, num_rdma_ranks, hip_bfloat16, kNumCombineForwarderWarps>;         \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine_func, reinterpret_cast<int4 *>(combined_x), combined_topk_weights,       \
            is_combined_token_in_rank, reinterpret_cast<const int4 *>(x), topk_weights,            \
            reinterpret_cast<const int4 *>(bias_0), reinterpret_cast<const int4 *>(bias_1),        \
            combined_rdma_head, combined_nvl_head, reinterpret_cast<const SourceMeta *>(src_meta), \
            rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,           \
            num_tokens, num_combined_tokens, hidden, num_topk, rdma_buffer_ptr,                    \
            num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, buffer_ptrs,       \
            num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, rank, num_ranks);    \
    }                                                                                              \
    break

lijian6's avatar
lijian6 committed
1916
    int num_rdma_ranks           = num_ranks / NUM_MAX_NVL_PEERS;
Chenggang Zhao's avatar
Chenggang Zhao committed
1917
    auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1);
lijian6's avatar
lijian6 committed
1918
1919
    int num_forwarder_warps      = num_rdma_ranks * num_warps_per_forwarder;
    EP_HOST_ASSERT(num_forwarder_warps >= NUM_MAX_NVL_PEERS);
lijian6's avatar
lijian6 committed
1920
    EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
1921
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0);
lijian6's avatar
lijian6 committed
1922
    EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens));
lijian6's avatar
lijian6 committed
1923
    EP_HOST_ASSERT(type == HIP_R_16BF);
Chenggang Zhao's avatar
Chenggang Zhao committed
1924

lijian6's avatar
lijian6 committed
1925
    SETUP_LAUNCH_CONFIG(num_channels * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL, (NUM_MAX_NVL_PEERS + 1) * kWarpSize, stream);
Chenggang Zhao's avatar
Chenggang Zhao committed
1926
1927
1928
    SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
1929

Chenggang Zhao's avatar
Chenggang Zhao committed
1930
1931
1932
} // namespace internode

} // namespace deep_ep
lijian6's avatar
lijian6 committed
1933

lijian6's avatar
lijian6 committed
1934
1935
1936
// #ifdef __clang__
// #pragma clang diagnostic pop
// #endif // __clang__
lijian6's avatar
lijian6 committed
1937
1938

#endif