#include "buffer.cuh"
#include "configs.cuh"
#include "hip/hip_runtime.h"
#include "launch.cuh"
#include "utils.cuh"

namespace deep_ep {

namespace intranode {

/**
 * Device kernel that exchanges per-rank / per-expert token counts between ranks and
 * pre-computes the prefix sums used later by the dispatch step.
 *
 * The grid is split by block index:
 *  - Block 0 publishes this rank's counts into every rank's buffer (`buffer_ptrs[i]`),
 *    then reduces the counts found in its own buffer (`buffer_ptrs[rank]`).
 *  - Block `b > 0` counts, for destination rank `b - 1`, how many tokens of each channel
 *    are selected in `is_token_in_rank`, then turns those counts into an inclusive
 *    prefix sum over channels.
 *
 * @tparam kNumRanks  Number of ranks; also the row/column size of the rank matrix.
 *
 * @param num_tokens_per_rank             Read at index `[0, kNumRanks)`: tokens sent to each rank.
 * @param moe_recv_counter_mapped         Written once with the total token count for this rank.
 * @param num_tokens_per_expert           Read at index `[0, num_experts)`.
 * @param moe_recv_expert_counter_mapped  Written at `[0, num_experts / kNumRanks)` with the
 *                                        aligned per-local-expert totals.
 * @param moe_recv_tokens_per_experts     Receives the same values, stored as `int64_t`.
 * @param num_experts                     Total number of experts (divided by `kNumRanks` here).
 * @param num_tokens                      Token count forwarded to `get_channel_task_range`.
 * @param num_channels                    Number of channels per destination rank.
 * @param is_token_in_rank                Indexed as `[token * kNumRanks + dst_rank]`.
 * @param channel_prefix_matrix           Written as `[dst_rank * num_channels + channel]`.
 * @param rank_prefix_matrix_copy         Receives `kNumRanks * kNumRanks` ints copied from the
 *                                        local rank matrix after the prefix sum.
 * @param num_memset_int                  Number of ints zeroed right after the rank matrix.
 * @param expert_alignment                Per-expert totals are rounded up to a multiple of this;
 *                                        used as a divisor, so it must be non-zero.
 * @param buffer_ptrs                     One buffer pointer per rank; entry `rank` is the local one.
 * @param barrier_signal_ptrs             Forwarded to `barrier_block`.
 * @param rank                            Index of the calling rank.
 *
 * NOTE(review): `barrier_block`, `get_channel_task_range`, `warp_reduce_sum` and `kWarpSize`
 * are defined outside this file section; their exact semantics should be confirmed there.
 * The indexing below implies that at least `kNumRanks` (and `num_experts / kNumRanks`)
 * threads exist in block 0 and that the grid has `1 + kNumRanks` blocks.
 */
template <int kNumRanks>
__global__ void
notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped,
                const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                int64_t *moe_recv_tokens_per_experts, int num_experts, int num_tokens,
                int num_channels, const bool *is_token_in_rank, int *channel_prefix_matrix,
                int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                void **buffer_ptrs, int **barrier_signal_ptrs, int rank) {
    auto sm_id     = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
    auto lane_id = thread_id % kWarpSize, warp_id = thread_id / kWarpSize,
         num_warps = num_threads / kWarpSize;

    if (sm_id == 0) {
        // Barrier first
        barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);

        // Thread `t < kNumRanks` addresses the buffer of rank `t`; the per-expert area
        // starts right after the `kNumRanks x kNumRanks` rank matrix.
        // Only threads below `kNumRanks` ever dereference these pointers.
        int *per_rank_buffer = nullptr, *per_expert_buffer = nullptr;
        if (thread_id < kNumRanks) {
            per_rank_buffer   = static_cast<int *>(buffer_ptrs[thread_id]);
            per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
        }

        // After this loop:
        //  - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
        //  - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local
        //    expert j
        int num_experts_per_rank = num_experts / kNumRanks;
        if (thread_id < kNumRanks) {
            per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id];
#pragma unroll
            for (int i = 0; i < num_experts_per_rank; ++i)
                per_expert_buffer[rank * num_experts_per_rank + i] =
                    num_tokens_per_expert[thread_id * num_experts_per_rank + i];
        }

        // Wait for all ranks to be finished
        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

        // Sum per-rank counts and return to CPU
        // Also pre-compute the prefix sum for data sending
        auto local_per_rank_buffer = static_cast<int *>(buffer_ptrs[rank]);
        if (thread_id < kNumRanks) {
            // Each thread owns one column and accumulates it down the rows in place.
#pragma unroll
            for (int i = 1; i < kNumRanks; ++i)
                local_per_rank_buffer[i * kNumRanks + thread_id] +=
                    local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
            // The last row of column `rank` is the total number of tokens sent to this rank.
            if (thread_id == rank)
                *moe_recv_counter_mapped =
                    local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
        }

        // Sum per-experts counts and return to CPU
        auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
        if (thread_id < num_experts_per_rank) {
            int sum = 0;
#pragma unroll
            for (int i = 0; i < kNumRanks; ++i)
                sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
            // Round up to the next multiple of `expert_alignment`.
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            moe_recv_expert_counter_mapped[thread_id] = sum;
            moe_recv_tokens_per_experts[thread_id]    = sum;
        }
        // The loops below read the rank matrix and overwrite the per-expert area, so every
        // thread must have finished the reductions above first.
        __syncthreads();

        // Copy rank size prefix matrix to another tensor
#pragma unroll
        for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
            rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];

        // Extra memset for later communication queue
#pragma unroll
        for (int i = thread_id; i < num_memset_int; i += num_threads)
            local_per_expert_buffer[i] = 0;

        // Barrier
        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
    } else {
        // Blocks 1.. each handle one destination rank.
        int dst_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx,
                                   token_end_idx);

            // Iterate over tokens
            int count = 0;
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += kWarpSize)
                count += is_token_in_rank[i * kNumRanks + dst_rank];
            count = warp_reduce_sum(count);
            if (lane_id == 0)
                channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
        }
        // `sm_id` is uniform within a block, so every thread of this block reaches the barrier.
        __syncthreads();

        // Pre-compute prefix sum for all channels
        if (thread_id == 0) {
#pragma unroll
            for (int i = 1; i < num_channels; ++i)
                channel_prefix_matrix[dst_rank * num_channels + i] +=
                    channel_prefix_matrix[dst_rank * num_channels + i - 1];
        }
    }
}

/**
 * Host-side launcher for the `notify_dispatch<kNumRanks>` kernel.
 *
 * Launches `1 + num_ranks` blocks of 128 threads on `stream`; `SWITCH_RANKS` selects the
 * template instantiation for `num_ranks`. All pointer and scalar arguments other than
 * `num_ranks` and `stream` are forwarded to the kernel unchanged (note that the kernel takes
 * `num_channels` before `is_token_in_rank`, whereas this wrapper takes it last).
 *
 * Preconditions checked here:
 *  - `num_ranks > 0` (it is used as a divisor below),
 *  - `expert_alignment > 0` (the kernel divides by it when rounding per-expert totals),
 *  - `num_experts` is a multiple of `num_ranks`,
 *  - both `num_experts / num_ranks` and `num_ranks` fit in one 128-thread block, because
 *    block 0 of the kernel assigns one thread per rank and one thread per local expert.
 *
 * NOTE(review): `SETUP_LAUNCH_CONFIG` presumably declares the `cfg` object consumed by
 * `LAUNCH_KERNEL_NON_COOPERATIVE`, and `SWITCH_RANKS` presumably switches on `num_ranks`;
 * both macros live in `launch.cuh` — confirm there.
 */
void notify_dispatch(const int *num_tokens_per_rank, int *moe_recv_counter_mapped, int num_ranks,
                     const int *num_tokens_per_expert, int *moe_recv_expert_counter_mapped,
                     int64_t *moe_recv_tokens_per_experts, int num_experts, int num_tokens,
                     const bool *is_token_in_rank, int *channel_prefix_matrix,
                     int *rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                     void **buffer_ptrs, int **barrier_signal_ptrs, int rank, hipStream_t stream,
                     int num_channels) {
// One `case` body for `SWITCH_RANKS`: launch the matching instantiation, then `break`.
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks)                                                         \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, notify_dispatch<ranks>, num_tokens_per_rank,               \
                                  moe_recv_counter_mapped, num_tokens_per_expert,                  \
                                  moe_recv_expert_counter_mapped, moe_recv_tokens_per_experts,     \
                                  num_experts, num_tokens, num_channels, is_token_in_rank,         \
                                  channel_prefix_matrix, rank_prefix_matrix_copy, num_memset_int,  \
                                  expert_alignment, buffer_ptrs, barrier_signal_ptrs, rank);       \
    break

    constexpr int kNumThreads = 128;
    // Guard the divisions performed below and inside the kernel.
    EP_HOST_ASSERT(num_ranks > 0);
    EP_HOST_ASSERT(expert_alignment > 0);
    EP_HOST_ASSERT(num_experts % num_ranks == 0);
    EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);

    // Block 0 does the count exchange; blocks 1..num_ranks each handle one destination rank.
    SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
    SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

/**
 * Simplified notify kernel for cached handles: restores a previously computed rank prefix
 * matrix into the local buffer and clears the area that follows it.
 *
 * @tparam kNumRanks  Number of ranks; the rank matrix holds `kNumRanks * kNumRanks` ints.
 *
 * @param rank_prefix_matrix   Source of the `kNumRanks * kNumRanks` ints to copy.
 * @param num_memset_int       Number of ints zeroed immediately after the rank matrix.
 * @param buffer_ptrs          One buffer pointer per rank; only entry `rank` is written here.
 * @param barrier_signal_ptrs  Forwarded to `barrier_block`.
 * @param rank                 Index of the calling rank.
 *
 * Both loops stride by `blockDim.x` starting from `threadIdx.x` and never use `blockIdx`,
 * so the kernel is written for a single-block launch.
 *
 * NOTE(review): `barrier_block` is defined outside this file section; the meaning of its
 * second template argument (`true` on the first call) should be confirmed there.
 */
template <int kNumRanks>
__global__ void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int,
                                       void **buffer_ptrs, int **barrier_signal_ptrs, int rank) {
    // A simplified version for cached handles
    barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);

    // Copy and clean
    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
    auto ptr = static_cast<int *>(buffer_ptrs[rank]);
#pragma unroll
    for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
        ptr[i] = rank_prefix_matrix[i];
    // Zero `num_memset_int` ints located right behind the rank matrix.
#pragma unroll
    for (int i = thread_id; i < num_memset_int; i += num_threads)
        ptr[kNumRanks * kNumRanks + i] = 0;

    // Barrier after cleaning
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}

/**
 * Host-side launcher for the `cached_notify_dispatch<kNumRanks>` kernel.
 *
 * Launches a single block of 256 threads on `stream`. Every argument except `num_ranks`
 * and `stream` is forwarded to the kernel unchanged.
 *
 * NOTE(review): `num_ranks` is not referenced directly in this body; it is presumably
 * consumed by `SWITCH_RANKS` to pick the template instantiation, and `SETUP_LAUNCH_CONFIG`
 * presumably declares the `cfg` object passed to `LAUNCH_KERNEL_NON_COOPERATIVE` — confirm
 * both in `launch.cuh`.
 */
void cached_notify_dispatch(const int *rank_prefix_matrix, int num_memset_int, void **buffer_ptrs,
                            int **barrier_signal_ptrs, int rank, int num_ranks,
                            hipStream_t stream) {
// One `case` body for `SWITCH_RANKS`: launch the matching instantiation, then `break`.
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks)                                                  \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_dispatch<ranks>, rank_prefix_matrix,         \
                                  num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank);         \
    break

    SETUP_LAUNCH_CONFIG(1, 256, stream);
    SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}

template <int kNumRanks, int kNumThreads>
Chenggang Zhao's avatar
Chenggang Zhao committed
174
__global__ void __launch_bounds__(kNumThreads, 1)
lijian6's avatar
lijian6 committed
175
176
177
178
179
180
181
    dispatch(int4 *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
             float *recv_topk_weights, int *recv_channel_offset, int *send_head, const int4 *x,
             const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
             const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
             int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
             int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
             int num_max_send_tokens, int num_recv_buffer_tokens) {
Chenggang Zhao's avatar
Chenggang Zhao committed
182
    const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
183
    const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
Chenggang Zhao's avatar
Chenggang Zhao committed
184
185
186
    const bool is_sender = sm_id % 2 == 0;
    EP_DEVICE_ASSERT(num_sms % 2 == 0);

Chenggang Zhao's avatar
Chenggang Zhao committed
187
188
    // Several warps are response for a single rank
    const auto num_threads_per_rank = kNumThreads / kNumRanks;
lijian6's avatar
lijian6 committed
189
190
    const auto num_channels         = num_sms / 2;
    const auto responsible_rank     = (static_cast<int>(thread_id)) / num_threads_per_rank;
Chenggang Zhao's avatar
Chenggang Zhao committed
191
    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
Chenggang Zhao's avatar
Chenggang Zhao committed
192
193
194
195
    const auto responsible_channel = sm_id / 2;

    int num_experts_per_rank = num_experts / kNumRanks;
    EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
lijian6's avatar
lijian6 committed
196
197
    EP_DEVICE_ASSERT(num_topk <= kWarpSize);
    EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
Chenggang Zhao's avatar
Chenggang Zhao committed
198
199
200
201
    EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    // Calculate pointers by the specific layout
    // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
lijian6's avatar
lijian6 committed
202
203
204
205
206
    auto ptr = reinterpret_cast<void *>(
        static_cast<int8_t *>(buffer_ptrs[is_sender ? responsible_rank : rank]) +
        kNumRanks * kNumRanks * sizeof(int));
    int  target_rank         = is_sender ? rank : responsible_rank;
    auto num_channels_total  = num_channels * kNumRanks;
Chenggang Zhao's avatar
Chenggang Zhao committed
207
208
209
210
211
    auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;

    // Channel buffer metadata
    // Senders are responsible for tails, and receivers are responsible for heads
    // Stored on the receiver side
lijian6's avatar
lijian6 committed
212
213
214
215
    // The retired signals are actually boolean flags, but to align with 16 bytes, we make it
    // `int64_t` `start_offset`: kNumChannels * kNumRanks * sizeof(int) `end_offset`: kNumChannels *
    // kNumRanks * sizeof(int) `head_idx`: kNumChannels * kNumRanks * sizeof(int) `tail_idx`:
    // kNumChannels * kNumRanks * sizeof(int)
Chenggang Zhao's avatar
Chenggang Zhao committed
216
    auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
lijian6's avatar
lijian6 committed
217
218
219
    auto channel_end_offset   = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
    auto channel_head_idx     = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
    auto channel_tail_idx     = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
Chenggang Zhao's avatar
Chenggang Zhao committed
220
221
222
223

    // Channel data buffers, stored on the receiver side
    // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
    // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
lijian6's avatar
lijian6 committed
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk *
    // sizeof(int64_t) `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
    // num_topk * sizeof(float) `x_scales_buffers`: kNumChannels * kNumRanks *
    // num_recv_buffer_tokens * num_scales * sizeof(float)
    auto channel_x_buffers =
        Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
                     channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
    auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens,
                                               channel_rank_offset * num_recv_buffer_tokens);
    auto channel_topk_idx_buffers =
        Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                        channel_rank_offset * num_recv_buffer_tokens * num_topk);
    auto channel_topk_weights_buffers =
        Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                      channel_rank_offset * num_recv_buffer_tokens * num_topk);
    auto channel_x_scales_buffers =
        Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales,
                      channel_rank_offset * num_recv_buffer_tokens * num_scales);
242

Chenggang Zhao's avatar
Chenggang Zhao committed
243
244
    if (is_sender) {
        // Workers for sending
lijian6's avatar
lijian6 committed
245
        constexpr int num_send_warps          = kNumThreads / kWarpSize;
Chenggang Zhao's avatar
Chenggang Zhao committed
246
        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
lijian6's avatar
lijian6 committed
247
248
249
        const auto    send_thread_id          = thread_id;
        const auto    send_warp_id_in_rank    = send_thread_id % num_threads_per_rank / kWarpSize;
        EP_DEVICE_ASSERT(kNumRanks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
250
        EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
251
252
253

        // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
        // NOTES: this is for distinguishing zero tokens
lijian6's avatar
lijian6 committed
254
255
256
257
258
        if (lane_id == 0 and send_warp_id_in_rank == 0) {
            int value = responsible_channel > 0
                            ? channel_prefix_matrix[responsible_rank * num_channels +
                                                    responsible_channel - 1]
                            : 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
259
            st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
Chenggang Zhao's avatar
Chenggang Zhao committed
260
            value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
Chenggang Zhao's avatar
Chenggang Zhao committed
261
262
            st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
        }
lijian6's avatar
lijian6 committed
263
        syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
264
265
266

        // Get tasks
        int token_start_idx, token_end_idx;
lijian6's avatar
lijian6 committed
267
268
        get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx,
                               token_end_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
269
270
271

        // Iterate over all tokens and send by chunks
        int cached_channel_tail_idx = 0;
lijian6's avatar
lijian6 committed
272
        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
Chenggang Zhao's avatar
Chenggang Zhao committed
273
            // Check destination queue emptiness, or wait a buffer to be released (rare cases)
Chenggang Zhao's avatar
Chenggang Zhao committed
274
            // NOTES: the head index received by different warps may not be the same
Chenggang Zhao's avatar
Chenggang Zhao committed
275
            auto start_time = clock64();
lijian6's avatar
lijian6 committed
276
277
278
279
280
281
282
283
284
285
286
287
288
289
            while (lane_id == 0) {
                // NOTES: we only consider the worst case, because counting the real numbers are
                // time-consuming
                int num_used_slots =
                    cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
                    break;

                // Rare cases to loop again
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf(
                        "DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n",
                        rank, responsible_channel);
                    trap();
Chenggang Zhao's avatar
Chenggang Zhao committed
290
291
                }
            }
lijian6's avatar
lijian6 committed
292
            syncwarp();
Chenggang Zhao's avatar
Chenggang Zhao committed
293
294
295

            int chunk_token_idx = 0;
            while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
lijian6's avatar
lijian6 committed
296
297
298
299
300
301
302
                // NOTES: for the same token, the warp assigned to save `send_head` may be different
                // from the warp assigned to send the following data
                if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
                    send_head[token_idx * kNumRanks + responsible_rank] =
                        is_token_in_rank[token_idx * kNumRanks + responsible_rank]
                            ? cached_channel_tail_idx
                            : -1;
Chenggang Zhao's avatar
Chenggang Zhao committed
303

Chenggang Zhao's avatar
Chenggang Zhao committed
304
                // Skip if not selected
Chenggang Zhao's avatar
Chenggang Zhao committed
305
                if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) {
lijian6's avatar
lijian6 committed
306
                    token_idx++;
Chenggang Zhao's avatar
Chenggang Zhao committed
307
308
309
310
                    continue;
                }

                // Get an empty slot
lijian6's avatar
lijian6 committed
311
                int dst_slot_idx = (cached_channel_tail_idx++) % num_recv_buffer_tokens;
Chenggang Zhao's avatar
Chenggang Zhao committed
312
313
                if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) {
                    // Copy data
lijian6's avatar
lijian6 committed
314
315
                    auto shifted_channel_x_buffers =
                        channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
Chenggang Zhao's avatar
Chenggang Zhao committed
316
                    auto shifted_x = x + token_idx * hidden_int4;
lijian6's avatar
lijian6 committed
317
318
                    UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_channel_x_buffers,
                                       shifted_x, __ldg, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
319
320

                    // Copy source index
lijian6's avatar
lijian6 committed
321
                    if (lane_id == 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
322
323
324
                        channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);

                    // Copy `topk_idx` and `topk_weights` with transformed index
325
                    if (lane_id < num_topk) {
Chenggang Zhao's avatar
Chenggang Zhao committed
326
                        // Top-k index
lijian6's avatar
lijian6 committed
327
328
329
330
331
332
                        int recv_expert_begin = responsible_rank * num_experts_per_rank,
                            recv_expert_end   = (responsible_rank + 1) * num_experts_per_rank;
                        auto idx_value        = __ldg(topk_idx + token_idx * num_topk + lane_id);
                        idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end)
                                        ? idx_value - recv_expert_begin
                                        : -1;
333
                        channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value;
Chenggang Zhao's avatar
Chenggang Zhao committed
334
335

                        // Top-k weights
336
                        auto weight_value = __ldg(topk_weights + token_idx * num_topk + lane_id);
lijian6's avatar
lijian6 committed
337
338
339
                        weight_value      = (idx_value >= 0) ? weight_value : 0.0f;
                        channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] =
                            weight_value;
Chenggang Zhao's avatar
Chenggang Zhao committed
340
341
                    }

lijian6's avatar
lijian6 committed
342
343
344
// Copy `x_scales`
#pragma unroll
                    for (int i = lane_id; i < num_scales; i += kWarpSize) {
Shifang Xu's avatar
Shifang Xu committed
345
                        auto offset = token_idx * scale_token_stride + i * scale_hidden_stride;
lijian6's avatar
lijian6 committed
346
347
                        channel_x_scales_buffers[dst_slot_idx * num_scales + i] =
                            __ldg(x_scales + offset);
Shifang Xu's avatar
Shifang Xu committed
348
                    }
Chenggang Zhao's avatar
Chenggang Zhao committed
349
350
351
                }

                // Move token index
lijian6's avatar
lijian6 committed
352
                chunk_token_idx++, token_idx++;
Chenggang Zhao's avatar
Chenggang Zhao committed
353
354
355
            }

            // Move tail index
Chenggang Zhao's avatar
Chenggang Zhao committed
356
            // NOTES: here all warps should share the same new tail
lijian6's avatar
lijian6 committed
357
358
359
360
361
362
363
            if (num_threads_per_rank > kWarpSize) {
                __syncthreads();
            } else {
                syncwarp();
            }
            if (send_warp_id_in_rank == 0 and lane_id == 0)
                st_relaxed_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
364
365
366
        }
    } else {
        // Workers for receiving and copying into buffer
lijian6's avatar
lijian6 committed
367
        constexpr int num_recv_warps          = kNumThreads / kWarpSize;
Chenggang Zhao's avatar
Chenggang Zhao committed
368
        constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
lijian6's avatar
lijian6 committed
369
370
371
372
        const auto    recv_thread_id          = thread_id;
        const auto    recv_thread_id_in_rank  = recv_thread_id % num_threads_per_rank;
        const auto    recv_warp_id_in_rank    = recv_thread_id_in_rank / kWarpSize;
        EP_DEVICE_ASSERT(kNumRanks <= kWarpSize);
Chenggang Zhao's avatar
Chenggang Zhao committed
373
        EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
374
375

        // Calculate offset first
lijian6's avatar
lijian6 committed
376
377
378
379
        auto rank_prefix_matrix = static_cast<int *>(buffer_ptrs[rank]);
        int  rank_offset        = responsible_rank > 0
                                      ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank]
                                      : 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
380
381
382

        // Receive channel offset
        int total_offset, num_tokens_to_recv;
lijian6's avatar
lijian6 committed
383
384
385
386
387
388
389
        while (lane_id == 0 and
               (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0)
            ;
        while (lane_id == 0 and
               (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0)
            ;
        if (lane_id == 0) {
Chenggang Zhao's avatar
Chenggang Zhao committed
390
            total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
Chenggang Zhao's avatar
Chenggang Zhao committed
391
            if (recv_warp_id_in_rank == 0)
lijian6's avatar
lijian6 committed
392
393
                recv_channel_offset[responsible_rank * num_channels + responsible_channel] =
                    total_offset;
Chenggang Zhao's avatar
Chenggang Zhao committed
394
395
            num_tokens_to_recv -= total_offset;
        }
lijian6's avatar
lijian6 committed
396
        total_offset = shfl_sync(total_offset, 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
397
        total_offset += rank_offset;
lijian6's avatar
lijian6 committed
398
        num_tokens_to_recv = shfl_sync(num_tokens_to_recv, 0);
Chenggang Zhao's avatar
Chenggang Zhao committed
399

Chenggang Zhao's avatar
Chenggang Zhao committed
400
401
402
        // Shared tail indices for different warps
        __shared__ volatile int shared_channel_tail_idx[kNumRanks];

lijian6's avatar
lijian6 committed
403
404
        auto start_time              = clock64();
        int  cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
Chenggang Zhao's avatar
Chenggang Zhao committed
405
        while (num_tokens_to_recv > 0) {
lijian6's avatar
lijian6 committed
406
407
            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by
            // different warps are the same
Chenggang Zhao's avatar
Chenggang Zhao committed
408
            while (recv_thread_id_in_rank == 0) {
lijian6's avatar
lijian6 committed
409
                cached_channel_tail_idx = ld_relaxed_sys_global(channel_tail_idx.buffer());
Chenggang Zhao's avatar
Chenggang Zhao committed
410
411

                // Ready to copy
Chenggang Zhao's avatar
Chenggang Zhao committed
412
413
                if (cached_channel_head_idx != cached_channel_tail_idx) {
                    shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
414
                    break;
Chenggang Zhao's avatar
Chenggang Zhao committed
415
                }
Chenggang Zhao's avatar
Chenggang Zhao committed
416
417
418

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
lijian6's avatar
lijian6 committed
419
420
421
                    printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = "
                           "%d, tokens remained: %d\n",
                           rank, responsible_channel, num_tokens_to_recv);
Chenggang Zhao's avatar
Chenggang Zhao committed
422
423
424
425
                    trap();
                }
            }

Chenggang Zhao's avatar
Chenggang Zhao committed
426
            // Synchronize queue tail
lijian6's avatar
lijian6 committed
427
428
429
430
431
            if (num_threads_per_rank > kWarpSize) {
                __syncthreads();
            } else {
                syncwarp();
            }
Chenggang Zhao's avatar
Chenggang Zhao committed
432
            cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank];
Chenggang Zhao's avatar
Chenggang Zhao committed
433
434
435

            // Copy data
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
lijian6's avatar
lijian6 committed
436
437
438
439
440
441
442
443
444
445
            for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens;
                 chunk_idx += num_recv_warps_per_rank) {
                int token_idx_in_buffer =
                    (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                auto shifted_buffer_x_int4 =
                    channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
                auto shifted_recv_x_int4 =
                    recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
                UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_recv_x_int4,
                                   shifted_buffer_x_int4, ld_nc_global, st_na_global);
Chenggang Zhao's avatar
Chenggang Zhao committed
446
447
            }

lijian6's avatar
lijian6 committed
448
449
450
451
452
453
454
455
456
457
458
459
// Copy `src_idx`
#pragma unroll 4
            for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank;
                 chunk_idx < cached_channel_tail_idx;
                 chunk_idx += kWarpSize * num_recv_warps_per_rank)
                recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(
                    channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);

// Copy `topk_idx` and `topk_weights`
#pragma unroll 4
            for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk;
                 idx += kWarpSize * num_recv_warps_per_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
460
                int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
lijian6's avatar
lijian6 committed
461
462
463
464
                int token_idx_in_buffer =
                    (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                auto recv_idx =
                    static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
Chenggang Zhao's avatar
Chenggang Zhao committed
465
                auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
lijian6's avatar
lijian6 committed
466
467
468
469
                recv_topk_idx[recv_idx] =
                    ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
                recv_topk_weights[recv_idx] =
                    ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
470
471
            }

lijian6's avatar
lijian6 committed
472
473
474
475
// Copy `x_scales`
#pragma unroll 4
            for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales;
                 i += kWarpSize * num_recv_warps_per_rank) {
Chenggang Zhao's avatar
Chenggang Zhao committed
476
                int chunk_idx = i / num_scales, scales_idx = i % num_scales;
lijian6's avatar
lijian6 committed
477
478
479
480
481
482
                int token_idx_in_buffer =
                    (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales +
                              scales_idx] =
                    ld_nc_global(channel_x_scales_buffers.buffer() +
                                 token_idx_in_buffer * num_scales + scales_idx);
Chenggang Zhao's avatar
Chenggang Zhao committed
483
484
485
486
487
            }

            // Move queue
            cached_channel_head_idx += num_recv_tokens;
            total_offset += num_recv_tokens;
lijian6's avatar
lijian6 committed
488
489
490
491
492
493
            if (num_threads_per_rank > kWarpSize) {
                __syncthreads();
            } else {
                syncwarp();
            }
            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and lane_id == 0)
Chenggang Zhao's avatar
Chenggang Zhao committed
494
495
496
497
498
499
                st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);

            // Exit
            num_tokens_to_recv -= num_recv_tokens;
        }
    }
500
501
502

    // Clean unused `recv_topk_idx` as -1
    if (num_worst_tokens > 0) {
lijian6's avatar
lijian6 committed
503
504
505
506
507
508
        auto       rank_prefix_matrix = static_cast<int *>(buffer_ptrs[rank]);
        const auto num_recv_tokens    = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
        const auto clean_start        = num_recv_tokens * num_topk + sm_id * kNumThreads;
        const auto clean_end          = num_worst_tokens * num_topk;
        const auto clean_stride       = num_sms * kNumThreads;
#pragma unroll
509
510
511
        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
            recv_topk_idx[i] = -1;
    }
Chenggang Zhao's avatar
Chenggang Zhao committed
512
513
}

lijian6's avatar
lijian6 committed
514
515
516
517
518
519
520
521
522
/**
 * Host launcher for the intranode `dispatch` kernel.
 *
 * Reinterprets `recv_x` / `x` as `int4` arrays and forwards every other argument unchanged to
 * `dispatch<ranks, kNumThreads>`, launched with `num_sms` blocks of `kNumThreads` (1024)
 * threads on `stream`.
 *
 * Preconditions (checked with `EP_HOST_ASSERT`):
 *   - `num_scales * scale_hidden_stride`, evaluated in 64 bits, is below `INT_MAX`;
 *   - `num_sms` is even, because even-numbered blocks send and odd-numbered blocks receive.
 *
 * NOTE(review): `num_ranks` is not referenced directly in this function; presumably
 * `SWITCH_RANKS` reads it to select the `ranks` template argument, and `SETUP_LAUNCH_CONFIG`
 * declares the `cfg` object consumed by the launch macro -- confirm against launch.cuh.
 */
void dispatch(void *recv_x, float *recv_x_scales, int *recv_src_idx, int64_t *recv_topk_idx,
              float *recv_topk_weights, int *recv_channel_offset, int *send_head, const void *x,
              const float *x_scales, const int64_t *topk_idx, const float *topk_weights,
              const bool *is_token_in_rank, const int *channel_prefix_matrix, int num_tokens,
              int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              int scale_token_stride, int scale_hidden_stride, void **buffer_ptrs, int rank,
              int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
              int num_recv_buffer_tokens) {
    constexpr int kNumThreads = 1024;

    // Make sure never OOB: the product is widened to 64 bits before the comparison so the
    // check itself cannot overflow.
    EP_HOST_ASSERT(static_cast<int64_t>(num_scales) * scale_hidden_stride <
                       std::numeric_limits<int>::max());

    // One `case` body per supported rank count; the trailing `break` is supplied here so that
    // `SWITCH_RANKS` can paste the macro directly after a `case` label.
#define DISPATCH_LAUNCH_CASE(ranks)                                                                \
    {                                                                                              \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, dispatch<ranks, kNumThreads>, reinterpret_cast<int4 *>(recv_x), recv_x_scales,   \
            recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, send_head,        \
            reinterpret_cast<const int4 *>(x), x_scales, topk_idx, topk_weights, is_token_in_rank, \
            channel_prefix_matrix, num_tokens, num_worst_tokens, hidden_int4, num_topk,            \
            num_experts, num_scales, scale_token_stride, scale_hidden_stride, buffer_ptrs, rank,   \
            num_max_send_tokens, num_recv_buffer_tokens);                                          \
    }                                                                                              \
    break

    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
    EP_HOST_ASSERT(num_sms % 2 == 0);
    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}
lijian6's avatar
lijian6 committed
546
547
548
549
/**
 * Prepares the local buffer and the `send_head` table before a cached combine.
 *
 * Work split by block index:
 *   - Block 0 runs `barrier_block`, zeroes the first `num_memset_int` ints of
 *     `buffer_ptrs[rank]`, and runs `barrier_block` again.
 *   - Block `b > 0` handles channel `b - 1`. Inside it, warp `w` handles rank `w`
 *     (warps with `w >= kNumRanks` exit immediately) and rewrites the negative entries of
 *     `send_head[token * kNumRanks + w]` for the tokens of that channel.
 *
 * Rewriting rule: tokens are visited from the end of the channel range towards its start.
 * `last_head` holds the most recent non-negative head seen so far (initially `1 << 25`);
 * every negative entry is replaced by `-last_head - 1`.
 *
 * @tparam kNumRanks          number of ranks; also the row width of `send_head`.
 * @param buffer_ptrs         per-rank buffer pointers; only `buffer_ptrs[rank]` is written here.
 * @param send_head           table of `num_recv_tokens * kNumRanks` ints, updated in place.
 * @param num_channels        number of channels the token range is split into.
 * @param num_recv_tokens     total number of tokens covered by `send_head`.
 * @param num_memset_int      number of leading ints of the local buffer to clear.
 * @param barrier_signal_ptrs forwarded to `barrier_block`.
 * @param rank                index of the calling rank.
 */
template <int kNumRanks>
__global__ void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
                                      int num_recv_tokens, int num_memset_int,
                                      int **barrier_signal_ptrs, int rank) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    if (sm_id == 0) {
        // Barrier before cleaning
        barrier_block<kNumRanks, true>(barrier_signal_ptrs, rank);

        // Clean: block-stride loop over the leading `num_memset_int` ints of the local buffer
        auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
        auto ptr = static_cast<int *>(buffer_ptrs[rank]);
#pragma unroll
        for (int i = thread_id; i < num_memset_int; i += num_threads)
            ptr[i] = 0;

        // Barrier after cleaning
        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
    } else {
        const auto channel_id = sm_id - 1;
        const auto thread_id  = static_cast<int>(threadIdx.x);
        const auto rank_id    = thread_id / kWarpSize;
        const auto lane_id    = thread_id % kWarpSize;
        // The launch may provide more warps than ranks; the surplus warps have nothing to do.
        if (rank_id >= kNumRanks)
            return;

        int token_start_idx, token_end_idx;
        get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx,
                               token_end_idx);

        // NOTES: `1 << 25` is a heuristic large number
        int last_head = 1 << 25;
#pragma unroll
        for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx;
             token_idx_tail -= kWarpSize) {
            // Lane `l` owns token `token_idx_tail - l`; lanes that fall before the start of the
            // range carry the sentinel -1 and never write back.
            int  token_idx = token_idx_tail - lane_id, expected_head = 0;
            auto current_head = (token_idx >= token_start_idx)
                                    ? __ldg(send_head + token_idx * kNumRanks + rank_id)
                                    : -1;
            // Replay the chunk in descending token order on every lane so that all lanes keep
            // the same running `last_head`.
            for (int i = 0; i < min(kWarpSize, token_idx_tail - token_start_idx + 1); ++i) {
                const int head = shfl_sync(current_head, i);
                if (head < 0) {
                    if (lane_id == i)
                        expected_head = -last_head - 1;
                } else {
                    last_head = head;
                }
            }
            if (current_head < 0 and token_idx >= token_start_idx)
                send_head[token_idx * kNumRanks + rank_id] = expected_head;
        }
    }
}

lijian6's avatar
lijian6 committed
600
601
602
603
604
605
606
/**
 * Host launcher for the `cached_notify_combine` kernel.
 *
 * Launches `1 + num_channels` blocks on `stream`: the kernel uses block 0 for the buffer
 * clean-up and one further block per channel. The block size is
 * `max(128, kWarpSize * num_ranks)` threads, which gives the kernel one warp per rank.
 *
 * Preconditions (checked with `EP_HOST_ASSERT`):
 *   - `num_ranks <= num_threads` and `num_threads <= 1024`;
 *   - `1 + num_channels <= 2 * num_channels`, i.e. there is at least one channel.
 *
 * NOTE(review): presumably `SWITCH_RANKS` dispatches on `num_ranks` and `SETUP_LAUNCH_CONFIG`
 * declares the `cfg` object used by the launch macro -- confirm against launch.cuh.
 */
void cached_notify_combine(void **buffer_ptrs, int *send_head, int num_channels,
                           int num_recv_tokens, int num_memset_int, int **barrier_signal_ptrs,
                           int rank, int num_ranks, hipStream_t stream) {
    // One `case` body per supported rank count; the trailing `break` lets `SWITCH_RANKS` paste
    // the macro directly after a `case` label.
#define CACHED_NOTIFY_COMBINE(ranks)                                                               \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_combine<ranks>, buffer_ptrs, send_head,      \
                                  num_channels, num_recv_tokens, num_memset_int,                   \
                                  barrier_signal_ptrs, rank);                                      \
    break

    const int num_threads = ::max(128, kWarpSize * num_ranks);
    EP_HOST_ASSERT(num_ranks <= num_threads);
    EP_HOST_ASSERT(num_threads <= 1024);
    EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
    SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
    SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
#undef CACHED_NOTIFY_COMBINE
}

lijian6's avatar
lijian6 committed
618
/**
 * Intranode combine kernel: pushes tokens into per-channel ring buffers and reduces them.
 *
 * Launch layout: `gridDim.x` must be even. Blocks are paired into channels
 * (`responsible_channel = blockIdx.x / 2`); the even block of a pair is the sender and the odd
 * block is the receiver. Each block has `kNumThreads` threads.
 *
 * Sender block: the warps are split evenly over the `kNumRanks` ranks. The warps assigned to
 * rank `r` copy this rank's tokens for (`r`, channel) -- hidden data, `src_idx` and
 * `topk_weights` -- into the ring buffer reached through `buffer_ptrs[r]`, at most
 * `num_max_send_tokens` tokens per round, then publish the new tail index.
 *
 * Receiver block: warp 0 mirrors the remote tail indices into shared memory and publishes the
 * minimum head over all active reducer warps. Every other warp reduces whole tokens: it waits
 * until the slots named by `send_head` have arrived, sums the contributions (plus the optional
 * `bias_0` / `bias_1`) in `float`, and writes `recv_x` and `recv_topk_weights`.
 *
 * @tparam dtype_t      element type of `x`, `recv_x` and the biases.
 * @tparam kNumRanks    number of ranks.
 * @tparam kNumThreads  threads per block; must be a multiple of `kNumRanks * kWarpSize`.
 * @param bias_0, bias_1  optional per-token biases; either may be `nullptr`.
 * @param send_head     per (token, rank) expected slot index; negative entries mean that the
 *                      rank contributes nothing for the token and encode `-head - 1`.
 * @param hidden        number of `dtype_t` elements per token.
 * @param num_topk      top-k width; asserted to be at most `kWarpSize`.
 *
 * NOTE(review): the `Buffer<T>(ptr, ...)` calls are assumed to advance `ptr` past the region
 * they describe, which is what makes the successive declarations lay the regions out back to
 * back -- confirm against buffer.cuh.
 */
template <typename dtype_t, int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
    combine(dtype_t *recv_x, float *recv_topk_weights, const dtype_t *x, const float *topk_weights,
            const dtype_t *bias_0, const dtype_t *bias_1, const int *src_idx,
            const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
            int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
            int rank, int num_max_send_tokens, int num_recv_buffer_tokens) {
    const auto num_sms   = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto sm_id = static_cast<int>(blockIdx.x), lane_id = get_lane_id();
    const auto num_channels        = num_sms / 2;
    const bool is_sender           = sm_id % 2 == 0;
    const int  responsible_channel = sm_id / 2;
    EP_DEVICE_ASSERT(num_topk <= kWarpSize);

    // All bulk data movement is done in 16-byte `int4` units.
    constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
    int           hidden_int4   = hidden * sizeof(dtype_t) / sizeof(int4);
    auto          x_int4        = reinterpret_cast<const int4 *>(x);
    auto          bias_0_int4   = reinterpret_cast<const int4 *>(bias_0);
    auto          bias_1_int4   = reinterpret_cast<const int4 *>(bias_1);
    auto          recv_int4     = reinterpret_cast<int4 *>(recv_x);

    if (is_sender) {
        // Workers for sending
        // Several warps are responsible for a single rank
        constexpr int num_send_warps_per_rank = (kNumThreads / kWarpSize) / kNumRanks;
        constexpr int num_send_warps          = num_send_warps_per_rank * kNumRanks;
        const auto    num_threads_per_rank    = num_send_warps_per_rank * kWarpSize;
        const auto    send_thread_id          = thread_id;
        const auto    send_warp_id            = send_thread_id / kWarpSize;
        // Warps are dealt to ranks round-robin, rotated by the channel index.
        const auto    send_rank_id            = (responsible_channel + send_warp_id) % kNumRanks;
        const auto    send_warp_id_in_rank    = send_warp_id / kNumRanks;
        EP_STATIC_ASSERT(num_send_warps * kWarpSize == kNumThreads, "Invalid warp count");

        // Calculate pointers by the specific layout
        auto ptr = reinterpret_cast<void *>(static_cast<int8_t *>(buffer_ptrs[send_rank_id]));
        auto num_channels_total  = num_channels * kNumRanks;
        auto channel_rank_offset = responsible_channel * kNumRanks + rank;

        // Channel meta data
        // `head_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 *
        //              sizeof(int4)
        // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
        // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
        //                         num_topk * sizeof(float)
        auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_x_buffers =
            Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
                         channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
        auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens,
                                                   channel_rank_offset * num_recv_buffer_tokens);
        auto channel_topk_weights_buffers =
            Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                          channel_rank_offset * num_recv_buffer_tokens * num_topk);

        // Get tasks
        // NOTES: `channel_offset` is already shifted
        int rank_offset =
            send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
        int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
        int channel_offset =
            channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
        int num_channel_tokens =
            (responsible_channel == num_channels - 1
                 ? num_rank_tokens
                 : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) -
            channel_offset;
        int token_start_idx = rank_offset + channel_offset,
            token_end_idx   = rank_offset + channel_offset + num_channel_tokens;

        // Iterate over all tokens and send by chunks
        int current_channel_tail_idx = 0;
        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx;) {
            // Check destination queue emptiness, or wait a buffer to be released (rare cases)
            auto start_time = wall_clock64();
            int  num_round_tokens =
                min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
            // Only lane 0 polls; the other lanes skip the loop and wait at `syncwarp()` below.
            while (lane_id == 0) {
                // NOTES: we only consider the worst case, because counting the real numbers are
                // time-consuming
                int num_used_slots =
                    current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
                    break;

                // Rare cases to loop again
                if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf(
                        "DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n",
                        rank, responsible_channel);
                    trap();
                }
            }
            syncwarp();

// Send by chunk
#pragma unroll
            for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
                // Get an empty slot
                int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;

                // Copy data
                auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                auto shifted_x         = x_int4 + (token_idx + i) * hidden_int4;
                UNROLLED_WARP_COPY(2, lane_id, hidden_int4, shifted_x_buffers, shifted_x,
                                   ld_nc_global, st_na_global);

                // Send source index
                if (lane_id == 0)
                    channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);

                // Send `topk_weights`
                if (num_topk > 0 and lane_id < num_topk)
                    channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] =
                        __ldg(topk_weights + (token_idx + i) * num_topk + lane_id);
            }
            token_idx += num_round_tokens;
            current_channel_tail_idx += num_round_tokens;

            // Move tail index
            // NOTE(review): when a rank owns more than one warp this is a block-wide
            // `__syncthreads()` inside a loop whose trip count depends on `send_rank_id`, so
            // warps serving different ranks may reach it a different number of times. Confirm
            // this is safe on the target hardware, or replace it with a per-rank barrier.
            if (num_threads_per_rank > kWarpSize) {
                __syncthreads();
            } else {
                syncwarp();
            }
            if (lane_id == 0 and send_warp_id_in_rank == 0)
                st_relaxed_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
        }
    } else {
        // Workers for receiving
        // One warp for moving the queue head, others for reduction
        constexpr int num_recv_warps = kNumThreads / kWarpSize;
        const auto    recv_warp_id   = thread_id / kWarpSize;
        EP_DEVICE_ASSERT(kNumRanks <= kWarpSize and kNumThreads > kWarpSize);
        EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % kWarpSize == 0);

        // Shared head, tail and retired flags for receiver warps
        __shared__ volatile int  warp_channel_head_idx[num_recv_warps][kNumRanks];
        __shared__ volatile int  channel_tail_idx[kNumRanks];
        __shared__ volatile bool warp_retired[num_recv_warps];
        if (thread_id < num_recv_warps)
            warp_retired[thread_id] = false;
        if (lane_id < kNumRanks)
            warp_channel_head_idx[recv_warp_id][lane_id] = 0;
        if (thread_id < kNumRanks)
            channel_tail_idx[thread_id] = 0;

        __syncthreads();

        if (thread_id < kWarpSize) {
            // Lane `l` (for `l < kNumRanks`) tracks the queue fed by rank `l`.
            int *channel_head_idx_ptr =
                static_cast<int *>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
            int *channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;

            // Queue head updater
            int last_head = 0;
            while (lane_id < kNumRanks) {
                // Check retired: stop once every reducer warp (index 1 and up) has finished
                bool retired = true;
#pragma unroll
                for (int i = 1; i < num_recv_warps; ++i)
                    retired = retired and warp_retired[i];
                if (retired)
                    break;

                // Update queue tail
                channel_tail_idx[lane_id] = ld_relaxed_sys_global(channel_tail_idx_ptr);

                // Update minimum head
                int min_head = std::numeric_limits<int>::max();
#pragma unroll
                for (int i = 1; i < num_recv_warps; ++i)
                    if (not warp_retired[i])
                        min_head = min(min_head, warp_channel_head_idx[i][lane_id]);
                // Publish only when the head actually advanced.
                if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
                    st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
            }
        } else {
            // Receivers
            // Channel metadata
            // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`
            Buffer<int4>  channel_x_buffers[kNumRanks];
            Buffer<float> channel_topk_weights_buffers[kNumRanks];

// Calculate pointers by the specific layout
#pragma unroll
            for (int i = 0; i < kNumRanks; ++i) {
                auto channel_rank_offset = responsible_channel * kNumRanks + i;
                auto num_channels_total  = num_channels * kNumRanks;
                // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
                auto ptr = reinterpret_cast<void *>(static_cast<int8_t *>(buffer_ptrs[rank]) +
                                                    2 * num_channels * kNumRanks * sizeof(int));

                // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 *
                // sizeof(int4)
                channel_x_buffers[i] =
                    Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4,
                                 channel_rank_offset * num_recv_buffer_tokens * hidden_int4);

                // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
                // sizeof(int)
                // Not read by the receiver, so the region is skipped by hand.
                ptr = reinterpret_cast<void *>(static_cast<int8_t *>(ptr) +
                                               num_channels_total * num_recv_buffer_tokens *
                                                   sizeof(int));

                // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens *
                // num_topk * sizeof(float)
                channel_topk_weights_buffers[i] =
                    Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk,
                                  channel_rank_offset * num_recv_buffer_tokens * num_topk);
            }

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_recv_tokens, num_channels, responsible_channel,
                                   token_start_idx, token_end_idx);

            // Iterate over all tokens and combine
            // Warp 0 is the head updater, so the reducer warps are numbered from 1 and the
            // stride is `num_recv_warps - 1`.
            for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx;
                 token_idx += num_recv_warps - 1) {
                // Read expected head
                int expected_head = -1;
                if (lane_id < kNumRanks)
                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + lane_id);

                auto start_time = wall_clock64();
                // `expected_head >= 0` is tested first: lanes at or beyond `kNumRanks` always
                // hold -1, so they never index the `kNumRanks`-sized shared array.
                while (__any(expected_head >= 0 and channel_tail_idx[lane_id] <= expected_head)) {
                    // Timeout check
                    if (wall_clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP timeout for combine receivers, rank %d, responsible_channel "
                               "= %d, expect = %d\n",
                               rank, responsible_channel, expected_head);
                        trap();
                    }
                }
                syncwarp();

                // Broadcast current heads
                int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
#pragma unroll
                for (int i = 0; i < kNumRanks; ++i) {
                    auto expected_head_i = __shfl(expected_head, i);
                    if (expected_head_i >= 0) {
                        slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
                        topk_ranks[num_topk_ranks++] = i;
                    }
                }

// Reduce data
#pragma unroll
                for (int i = lane_id; i < hidden_int4; i += kWarpSize) {
                    // Read bias
                    // TODO: make it as a template
                    int4  bias_0_value_int4     = bias_0_int4 != nullptr
                                                      ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i)
                                                      : make_int4(0, 0, 0, 0);
                    int4  bias_1_value_int4     = bias_1_int4 != nullptr
                                                      ? __ldg(bias_1_int4 + token_idx * hidden_int4 + i)
                                                      : make_int4(0, 0, 0, 0);
                    float values[kDtypePerInt4] = {0};
                    auto  bias_0_values = reinterpret_cast<const dtype_t *>(&bias_0_value_int4);
                    auto  bias_1_values = reinterpret_cast<const dtype_t *>(&bias_1_value_int4);
#pragma unroll
                    for (int j = 0; j < kDtypePerInt4; ++j)
                        values[j] = static_cast<float>(bias_0_values[j]) +
                                    static_cast<float>(bias_1_values[j]);
                    // Accumulate every contributing rank's slot in `float`.
#pragma unroll
                    for (int j = 0; j < num_topk_ranks; ++j) {
                        int4 recv_value = __ldg(channel_x_buffers[topk_ranks[j]].buffer() +
                                                slot_indices[j] * hidden_int4 + i);
                        const dtype_t *recv_dtypes = reinterpret_cast<const dtype_t *>(&recv_value);

#pragma unroll
                        for (int k = 0; k < kDtypePerInt4; ++k)
                            values[k] += static_cast<float>(recv_dtypes[k]);
                    }

                    // Cast back to `dtype_t`
                    int4 out_int4;
                    auto out_dtypes = reinterpret_cast<dtype_t *>(&out_int4);
#pragma unroll
                    for (int j = 0; j < kDtypePerInt4; ++j)
                        out_dtypes[j] = static_cast<dtype_t>(values[j]);

                    recv_int4[token_idx * hidden_int4 + i] = out_int4;
                }

                // Reduce `topk_weights`
                if (lane_id < num_topk) {
                    float value = 0;
#pragma unroll
                    for (int i = 0; i < num_topk_ranks; ++i)
                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() +
                                              slot_indices[i] * num_topk + lane_id);
                    recv_topk_weights[token_idx * num_topk + lane_id] = value;
                }

                // Update head
                // A negative `expected_head` encodes `-head - 1`; decode it so that the reported
                // head still advances for ranks that sent nothing for this token.
                if (lane_id < kNumRanks)
                    warp_channel_head_idx[recv_warp_id][lane_id] =
                        (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
            }

            // Retired
            syncwarp();
            if (lane_id == 0)
                warp_retired[recv_warp_id] = true;
        }
    }
}

lijian6's avatar
lijian6 committed
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
/**
 * Host launcher for the intranode `combine` kernel.
 *
 * Casts the type-erased `recv_x`, `x`, `bias_0` and `bias_1` pointers to the element type
 * chosen by the dispatch macros and launches `combine<dtype, ranks, kNumThreads>` with
 * `num_sms` blocks of `kNumThreads` (1024) threads on `stream`. All other arguments are
 * forwarded unchanged.
 *
 * Preconditions (checked with `EP_HOST_ASSERT`):
 *   - `num_sms` is even, because even-numbered blocks send and odd-numbered blocks receive;
 *   - `kNumThreads >= num_ranks * kWarpSize`, i.e. a block holds at least one warp per rank.
 *
 * NOTE(review): presumably `SWITCH_TYPES` dispatches on `type`, `SWITCH_RANKS_WITH_DTYPE` on
 * `num_ranks`, and `SETUP_LAUNCH_CONFIG` declares the `cfg` object used by the launch macro --
 * confirm against launch.cuh.
 */
void combine(hipDataType type, void *recv_x, float *recv_topk_weights, const void *x,
             const float *topk_weights, const void *bias_0, const void *bias_1, const int *src_idx,
             const int *rank_prefix_matrix, const int *channel_prefix_matrix, int *send_head,
             int num_tokens, int num_recv_tokens, int hidden, int num_topk, void **buffer_ptrs,
             int rank, int num_ranks, hipStream_t stream, int num_sms, int num_max_send_tokens,
             int num_recv_buffer_tokens) {
    constexpr int kNumThreads = 1024;

    // Inner case: one launch per (dtype, rank count) pair. The trailing `break` lets the
    // switch macros paste the body directly after a `case` label.
#define COMBINE_LAUNCH_CASE(dtype, ranks)                                                          \
    {                                                                                              \
        LAUNCH_KERNEL_NON_COOPERATIVE(                                                             \
            &cfg, combine<dtype, ranks, kNumThreads>, reinterpret_cast<dtype *>(recv_x),           \
            recv_topk_weights, reinterpret_cast<const dtype *>(x), topk_weights,                   \
            reinterpret_cast<const dtype *>(bias_0), reinterpret_cast<const dtype *>(bias_1),      \
            src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, num_tokens,             \
            num_recv_tokens, hidden, num_topk, buffer_ptrs, rank, num_max_send_tokens,             \
            num_recv_buffer_tokens);                                                               \
    }                                                                                              \
    break
    // Outer case: for a given dtype, switch over the rank count.
#define COMBINE_DTYPE_LAUNCH_CASE(dtype)                                                           \
    SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE);                                           \
    break

    // Even-numbered blocks for sending, odd-numbered blocks for receiving
    EP_HOST_ASSERT(num_sms % 2 == 0);
    EP_HOST_ASSERT(kNumThreads >= num_ranks * kWarpSize);
    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
#undef COMBINE_DTYPE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}

} // namespace intranode

} // namespace deep_ep