internode_ll.cu 40.5 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
4
5
6
7
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
lishen's avatar
lishen committed
8
9
10

#include "hip/hip_runtime.h"

11
#include "shmem_wrapper.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
12
13
14
15
16

namespace deep_ep {

namespace internode_ll {

lishen's avatar
lishen committed
17
18
19
20
21
22
23
// Packs two values of type `dtype_a_t` into one value of type `dtype_b_t`.
// `x` occupies the first `sizeof(dtype_a_t)` bytes of the result and `y` the
// following ones; `dtype_b_t` must be exactly twice the size of `dtype_a_t`
// (checked at compile time).
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t packed;
    // View the wide value as an array of two narrow slots and fill them in order
    auto unpacked_ptr = reinterpret_cast<dtype_a_t*>(&packed);
    unpacked_ptr[0] = x;
    unpacked_ptr[1] = y;
    return packed;
}

// Software grid-wide barrier built on a single counter in global memory.
//
// Every block increments `global_counter[0]` once (thread 0 only) and then
// thread 0 spins until the counter equals `num_blocks`.
//
// Preconditions visible from this code:
//  - must be called by all threads of every participating block (it contains
//    `__syncthreads()`);
//  - the counter is never reset here, so it has to hold 0 on entry and the
//    barrier can be crossed only once per reset
//    (NOTE(review): presumably the host zeroes it before each launch - confirm);
//  - all `num_blocks` blocks must be resident at the same time, otherwise the
//    spin loop cannot terminate.
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    // NOTE(review): the fetched value is stored to a volatile, which forces the
    // value-returning form of the atomic; kept in case this is deliberate.
    volatile int ret;
    __syncthreads();
    // Make this thread's earlier global writes visible device-wide before arrival is signalled
    __threadfence();
    if (threadIdx.x == 0) {
        // Release: everything ordered before the arrival is published together with it
        ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        // Acquire: once all arrivals are observed, the other blocks' pre-barrier writes are visible too
        while (__hip_atomic_load(global_counter, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) != num_blocks);
    }
    // Hold the rest of the block until thread 0 has seen every block arrive
    __syncthreads();
}
lishen's avatar
lishen committed
39
40
41
42
43
44
45
46
47
48
49
// Returns `a / b` rounded up, computed as `(a + b - 1) / b`.
// Intended for non-negative `a` and positive `b`; `b == 0` is not handled.
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    // Bias the numerator so that the truncating division rounds up
    const dtype_t biased_numerator = a + b - 1;
    return biased_numerator / b;
}

// Inverse of `pack2`: splits `packed` into its two `dtype_a_t` halves.
// `x` receives the first `sizeof(dtype_a_t)` bytes and `y` the following ones;
// `dtype_b_t` must be exactly twice the size of `dtype_a_t` (checked at compile time).
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    // View the wide value as an array of two narrow slots and read them in order
    const auto* halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
50
51


Chenggang Zhao's avatar
Chenggang Zhao committed
52
// Zeroes two low-latency buffers, bracketed by all-PE barriers.
//
// Launch layout: the loops stride by `kNumThreads` and index by `threadIdx.x`
// only, so the kernel expects a single block of exactly `kNumThreads` threads.
//
//  clean_0 / clean_1                 : buffers to zero (64-bit elements)
//  num_clean_int_0 / num_clean_int_1 : number of elements to zero in each buffer
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1) {
    // Barrier before cleaning (in case of unfinished chunked EP)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
    // Only thread 0 takes part in the barrier above; hold back the other
    // threads so that nobody starts zeroing before the barrier has completed
    __syncthreads();

    // Clean
    auto thread_id = static_cast<int>(threadIdx.x);
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Every thread's zeroing must be finished and visible before thread 0
    // announces completion to the other PEs through the barrier below
    __threadfence_system();
    __syncthreads();

    // Barrier after cleaning (make sure the low-latency mode works)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
}

73
74
75
76
77
78
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              hipStream_t stream) {
    constexpr int kNumThreads = 256;

    SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
lishen's avatar
lishen committed
79
80
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kNumThreads>,
                  clean_0, num_clean_int_0, clean_1, num_clean_int_1);
Chenggang Zhao's avatar
Chenggang Zhao committed
81
82
}

lishen's avatar
lishen committed
83
template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden>
lishen's avatar
lishen committed
84
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
85
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
lishen's avatar
lishen committed
86
87
88
89
90
91
92
93
94
         int* packed_recv_src_info, int64_t* packed_recv_layout_range,
         int* packed_recv_count,
         int* global_atomic_counter,
         void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
         const void* x, const int64_t* topk_idx,
         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
         int64_t* next_clean, int num_next_clean_int,
         int num_tokens, int num_max_dispatch_tokens_per_rank,
         int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
95
         int num_warp_groups, int num_warps_per_group,
lishen's avatar
lishen committed
96
         bool round_scale, int phases) {
97
98
99
100
101
102
103
104
105
106
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

lishen's avatar
lishen committed
107
108
109
110
111
    // May extract UE8M0 from the scales
    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");

112
113
    // FP8 staffs
    constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
lishen's avatar
lishen committed
114
    constexpr int kNumScales = kHidden / kNumPerChannels;
115
116
117
    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

lishen's avatar
lishen committed
118
    // Message package: hidden data, FP8 scales, index at source
119
    // NOTES: currently we have 3 reserved int fields for future use
lishen's avatar
lishen committed
120
    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
lishen's avatar
lishen committed
121
    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
122
123
124
    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

lishen's avatar
lishen committed
125
126
127
128
129
    // Expert counts
    constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

    // Sending phase
130
131
132
133
134
135
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

    // There are 2 kinds of warps in this part:
    // 1. The first-kind warps for FP8 cast and sending top-k tokens
    // 2. The last warp for reading `topk_idx` and count for per-expert information
lishen's avatar
lishen committed
136
137
138
    if (warp_id < num_warps) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
139
140
        EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * kWarpSize;
lishen's avatar
lishen committed
141
        constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
142
143

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
lishen's avatar
lishen committed
144
145
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
146
147
148
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

lishen's avatar
lishen committed
149
            // Overlap top-k index read and source token index write
150
151
152
            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

lishen's avatar
lishen committed
153
154
155
156
157
158
159
160
            __shared__ float int8_amaxf[kNumScales];
            if constexpr(kUseInt8) {
                if (thread_id < kNumScales) {
                    int8_amaxf[thread_id] = kFP8Margin;
                }
                __syncthreads();
            }

161
162
163
164
165
166
            // FP8 cast
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

lishen's avatar
lishen committed
167
                if constexpr(kUseFP8) {
168
169
170
171
172
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
                    #pragma unroll
lishen's avatar
lishen committed
173
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
174
175
176
177
178
179
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }
                    // Reduce amax and scale
                    EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
                    amax = warp_reduce_max<16>(amax);
lishen's avatar
lishen committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
                    const int scale_offset = i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL;

                    if constexpr(kUseInt8) {
                        // 记录每128个数的最大值
                        int8_amaxf[scale_offset] = fmaxf(amax, int8_amaxf[scale_offset]);
                    } else {
                        calculate_fp8_scales(amax, scale, scale_inv, round_scale);
                        if (lane_id % 16 == 0)
                            rdma_x_scales[scale_offset] = scale_inv;

                        // Cast into send buffer
                        vec_t int2_value;
                        auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
                        #pragma unroll
                        for (int j = 0; j < kNumElemsPerRead; j += 2) {
                            float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                            fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
                        }
                        rdma_x_vec[i] = int2_value;
199
200
201
202
203
204
205
                    }
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            __syncthreads();
lishen's avatar
lishen committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247

            if constexpr(kUseInt8) {
                float amax_per_token = kFP8Margin;
                // 并行规约,计算每个token的amax
                for (int s = 0; s < kNumScales; s+=kWarpSize) {
                    int src_idx = s + lane_id;
                    float tmp_amaxf = 0;
                    if(src_idx < kNumScales) {
                        tmp_amaxf = int8_amaxf[src_idx];
                    }
                    tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
                    int8_amaxf[0] = fmaxf(tmp_amaxf, int8_amaxf[0]);
                    __syncthreads();
                }
                amax_per_token = int8_amaxf[0];

                // 根据最大值计算scale
                float scale, scale_inv;
                calculate_int8_scales(amax_per_token, scale, scale_inv);
                if (thread_id == 0) {
                    rdma_x_scales[0] = scale_inv;
                }

                for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                    // Read
                    auto int4_value = __ldg(x_int4 + i);
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);

                    // Cast into send buffer
                    vec_t int2_value;
                    auto int8_values = reinterpret_cast<int8_t*>(&int2_value);
#pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
                        auto fp32_value = static_cast<float>(bf16_values[j]);
                        auto fp32_value_scaled = fp32_value * scale;
                        int8_values[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
                    }
                    rdma_x_vec[i] = int2_value;
                }
                __syncthreads();
            }

248
249
250
251
            // Issue IBGDA sends
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = shfl_sync(slot_idx, 0);
lishen's avatar
lishen committed
252
253
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
254
255
256
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
lishen's avatar
lishen committed
257
258
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;
259
                if (dst_rank != rank) {
lijian6's avatar
lijian6 committed
260
261
#if defined(FORCE_DUSHMEM_API)
                    void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
lijian6's avatar
lijian6 committed
262
                    if (peer_base_addr) {
lijian6's avatar
lijian6 committed
263
                        char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
lijian6's avatar
lijian6 committed
264
265
266
                        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                        const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
                        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
267
268
269
270
271
                    } else {
                        internode::shmemx_int8_put_nbi_warp(
                            reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
                            num_bytes_per_msg, dst_rank);
                    }
lishen's avatar
lishen committed
272
#else
273
                    #if !defined(ROCM_USE_MULTIQP)
lishen's avatar
lishen committed
274
275
276
                    internode::shmemx_int8_put_nbi_warp(
                        reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
                        num_bytes_per_msg, dst_rank);
277
278
279
280
281
                    #else
                    internode::shmemx_int8_put_nbi_warp_dp(
                        reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
                        num_bytes_per_msg, (dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
                    #endif
lijian6's avatar
lijian6 committed
282
#endif // defined(FORCE_DUSHMEM_API)
283
284
285
286
                } else {
                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
lishen's avatar
lishen committed
287
                    UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
288
                }
lishen's avatar
lishen committed
289

290
291
292
293
294
                // Increase counter after finishing
                syncwarp();
                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
lishen's avatar
lishen committed
295
296
    }
    if (warp_id == num_warps - 1) {
297
298
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
lishen's avatar
lishen committed
299
            // The first SM is also responsible for checking QPs
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
            // The first SM is also responsible for cleaning the next buffer
            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += kWarpSize)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }
        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kNumMaxWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
lishen's avatar
lishen committed
321
                expert_count[idx - expert_begin_idx] ++;
322
323
324
325
        }

        // Warp reduce
        #pragma unroll
lishen's avatar
lishen committed
326
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
lishen's avatar
lishen committed
344
        if (dst_rank != rank) {
lijian6's avatar
lijian6 committed
345
346
#if defined(FORCE_DUSHMEM_API)
            void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
lijian6's avatar
lijian6 committed
347
            if (peer_base_addr) {   // P2P enabled
348
                int *rptr_actual = (int *)((char *)(peer_base_addr) +
lijian6's avatar
lijian6 committed
349
                    ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(dushmemi_device_state_d.heap_base)));
lijian6's avatar
lijian6 committed
350
                st_na_release(rptr_actual, -num_tokens_sent - 1);
351
352
353
354
            } else {
                internode::shmem_long_atomic_add(
                    rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
            }
lishen's avatar
lishen committed
355
#else
356
            #if !defined(ROCM_USE_MULTIQP)
lishen's avatar
lishen committed
357
358
            internode::shmem_long_atomic_add(
                rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
359
360
361
362
363
            #else
            internode::shmem_long_atomic_add_dp(
                rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
                (dst_expert_local_idx + 1) * num_ranks + dst_rank, dst_rank);
            #endif
lijian6's avatar
lijian6 committed
364
#endif // defined(FORCE_DUSHMEM_API)
lishen's avatar
lishen committed
365
366
        } else {
            st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
367
368
369
370
371
372
373
374
375
376
377
378
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    syncwarp();

lishen's avatar
lishen committed
379
380
    // Receiving phase
LOW_LATENCY_DISPATCH_RECV:
381
382
383
384
385
386
387
388
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
    if (phases & LOW_LATENCY_SEND_PHASE){
        grid_barrier(global_atomic_counter, num_sms);
    }

lishen's avatar
lishen committed
389
390
391
392
393
394
395
396
397
398
399
    // 16 is the max possible number of warps in AMD GPUs
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    constexpr int num_sync_large_iteration = kMaxNumWarps ;
    __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];

#pragma unroll
    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

400
401
402
403
    // Receiving and packing
    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
lishen's avatar
lishen committed
404
        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
lishen's avatar
lishen committed
405
406
                                       local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                       src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
lishen's avatar
lishen committed
407
        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
lishen's avatar
lishen committed
408
                                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
409
410
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
lishen's avatar
lishen committed
411
        const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
lishen's avatar
lishen committed
412
        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
lishen's avatar
lishen committed
413
414
                                   local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
                                       (kUseInt8 ? 1 : num_aligned_scales);
415
416
417
418
419
420

        // Shared between sub-warps in warp groups
        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];

        // Wait tokens to arrive
        // NOTES: using sub-warp 1 to overlap with sub-warp 0
lishen's avatar
lishen committed
421
422
        int num_recv_tokens, recv_token_begin_idx;
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
423
424

        if (sub_warp_id == 1 and lane_id == 0) {
lishen's avatar
lishen committed
425
            while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
426
            num_recv_tokens = -num_recv_tokens - 1;
lishen's avatar
lishen committed
427
428
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
429
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
lishen's avatar
lishen committed
430
            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
431
432
        }

433
434
435
436
437
438
#if defined(ROCM_USE_MULTIQP)
        if (sub_warp_id == 2 and lane_id == 0) {
            internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
        }
#endif

439
440
        // no needs to reset because there is no iteration
        if (lane_id == 0){
lishen's avatar
lishen committed
441
            volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
442
443
444
        }
        syncwarp();

lishen's avatar
lishen committed
445
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
446
447
448
449
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        // Copy tokens
lishen's avatar
lishen committed
450
        EP_DEVICE_ASSERT(kNumScales <= 64);
451
452
453
454
455
456
457
458
459
460
461
        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
            // Copy source info
            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            syncwarp();

            // Copy data
            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
lishen's avatar
lishen committed
462
            UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
463
464

            // Copy scales
lishen's avatar
lishen committed
465
            if constexpr(kUseFP8) {
466
                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
lishen's avatar
lishen committed
467
468
469
470
471
                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                const auto token_idx = recv_token_begin_idx + i;
                const auto token_stride = num_elems_per_pack;
                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;

lishen's avatar
lishen committed
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
                if constexpr(kUseInt8) {
                    if (lane_id == 0) {
                        recv_x_scales[token_idx] = ld_nc_global(src_scales);
                    }
                } else {
                    if (lane_id < kNumScales) {
                        const auto pack_idx = lane_id / num_elems_per_pack;
                        const auto elem_idx = lane_id % num_elems_per_pack;
                        auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
                        recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                    }
                    if (lane_id + kWarpSize < kNumScales) {
                        const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
                        const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
                        auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
                        recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                    }
lishen's avatar
lishen committed
489
                }
490
491
492
            }
        }
    }
Chenggang Zhao's avatar
Chenggang Zhao committed
493
494
}

lishen's avatar
lishen committed
495
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
lishen's avatar
lishen committed
496
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
497
              int* packed_recv_count,
498
              int* global_atomic_counter,
lishen's avatar
lishen committed
499
500
501
502
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
lishen's avatar
lishen committed
503
              int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
504
              bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
lishen's avatar
lishen committed
505
              void* workspace, int num_device_sms,
lishen's avatar
lishen committed
506
              hipStream_t stream, int phases) {
507
    constexpr int kNumMaxTopK = 11;
lishen's avatar
lishen committed
508
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
lishen's avatar
lishen committed
509
    const int num_warps_per_group = 16 / num_warp_groups;
510
511
512
513
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

    const auto num_warps = num_warp_groups * num_warps_per_group;
lishen's avatar
lishen committed
514
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
515
516
517
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Workspace checks
lishen's avatar
lishen committed
518
    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
519
520
521
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

lishen's avatar
lishen committed
522
#define DISPATCH_LAUNCH_CASE(hidden) { \
lishen's avatar
lishen committed
523
524
525
526
527
528
529
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0)             \
    dispatch_func = dispatch<true, false, false, hidden>;   \
if (use_fp8 and use_ue8m0)             \
    dispatch_func = dispatch<true, true, false, hidden>;    \
if (use_int8)             \
    dispatch_func = dispatch<true, false, true, hidden>;    \
lishen's avatar
lishen committed
530
531
532
533
534
535
536
537
538
539
540
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
              packed_recv_x, packed_recv_x_scales, \
              packed_recv_src_info, packed_recv_layout_range, \
              packed_recv_count, \
              global_atomic_counter, \
              rdma_recv_x, rdma_recv_count, rdma_x, \
              x, topk_idx, \
              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
              next_clean, num_next_clean_int, \
              num_tokens, num_max_dispatch_tokens_per_rank, \
              num_topk, num_experts, rank, num_ranks, \
lishen's avatar
lishen committed
541
              num_warp_groups, num_warps_per_group, round_scale, phases); } break
542
543
544
545

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
546
547
}

lishen's avatar
lishen committed
548
549
template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
550
551
552
553
554
combine(void* combined_x,
        void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* global_atomic_counter,
lishen's avatar
lishen committed
555
        int64_t* combine_wait_recv_cost_stats,
lishen's avatar
lishen committed
556
557
558
559
560
        int64_t* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
561
        int num_warp_groups, int num_warps_per_group,
lishen's avatar
lishen committed
562
563
564
565
566
567
568
        int phases, bool zero_copy) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
lishen's avatar
lishen committed
569
570
571
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
lishen's avatar
lishen committed
572
573
574
575
576
577

    // Data type staffs
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    // Message package
lishen's avatar
lishen committed
578
    EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
lishen's avatar
lishen committed
579
    constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
lishen's avatar
lishen committed
580
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
lishen's avatar
lishen committed
581

lishen's avatar
lishen committed
582
    // 16 is the max possible number of warps in AMD GPUs
lishen's avatar
lishen committed
583
584
585
586
587
588
589
590
591
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
    if (threadIdx.x==0){
        #pragma unroll
        for (int i = 0; i < kMaxNumWarps; ++i) {
            sync_large_warp_counters[i] = 0;
        }
    }
    __syncthreads();
592

lishen's avatar
lishen committed
593
594
595
    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;
Chenggang Zhao's avatar
Chenggang Zhao committed
596

lishen's avatar
lishen committed
597
598
599
600
601
    // Clean up next buffer
    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
            next_clean[i] = 0;
602

lishen's avatar
lishen committed
603
604
605
606
607
        // Notify before executing `int_p`
        syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }
608

lishen's avatar
lishen committed
609
610
611
612
613
614
615
    // Issue IBGDA sends
    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
lishen's avatar
lishen committed
616
                             local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
lishen's avatar
lishen committed
617
618
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
lishen's avatar
lishen committed
619
                                     local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
lishen's avatar
lishen committed
620
621
622
623
624
625

        // Unpack layout
        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        // Issue IBGDA send
lishen's avatar
lishen committed
626
        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
lishen's avatar
lishen committed
627
628
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
lishen's avatar
lishen committed
629
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
lishen's avatar
lishen committed
630
631

            // Copy directly to local rank, or copy to buffer and issue RDMA
632
            const auto src_idx = __ldg(local_src_info + token_idx);
lishen's avatar
lishen committed
633
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
lishen's avatar
lishen committed
634
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
lishen's avatar
lishen committed
635
636
637
638
639
640
641
            if (dst_rank == rank) {
                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
            } else {
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
642

lijian6's avatar
lijian6 committed
643
644
#if defined(FORCE_DUSHMEM_API)
                void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
lijian6's avatar
lijian6 committed
645
                if (peer_base_addr) {
lijian6's avatar
lijian6 committed
646
                    char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(dushmemi_device_state_d.heap_base));
lijian6's avatar
lijian6 committed
647
648
                    const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
                    UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
649
650
651
652
653
                } else {
                    internode::shmemx_int8_put_nbi_warp(
                        reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
                        hidden * sizeof(hip_bfloat16), dst_rank);
                }
lishen's avatar
lishen committed
654
#else
655
                #if !defined(ROCM_USE_MULTIQP)
lishen's avatar
lishen committed
656
657
658
                internode::shmemx_int8_put_nbi_warp(
                    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
                    hidden * sizeof(hip_bfloat16), dst_rank);
659
660
661
662
663
                #else
                internode::shmemx_int8_put_nbi_warp_dp(
                    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
                    hidden * sizeof(hip_bfloat16), (local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
                #endif
lijian6's avatar
lijian6 committed
664
#endif // defined(FORCE_DUSHMEM_API)
lishen's avatar
lishen committed
665
            }
lishen's avatar
lishen committed
666
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
667

lishen's avatar
lishen committed
668
        // Put finishing flag
lishen's avatar
lishen committed
669
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
670
        if (lane_id == 0){
lishen's avatar
lishen committed
671
            volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
lishen's avatar
lishen committed
672
673
        }
        syncwarp();
lishen's avatar
lishen committed
674
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
lishen's avatar
lishen committed
675

lishen's avatar
lishen committed
676
677
678
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
            if (dst_rank != rank) {
lijian6's avatar
lijian6 committed
679
680
#if defined(FORCE_DUSHMEM_API)
                void *peer_base_addr = (void *)__ldg((const long long unsigned *)dushmemi_device_state_d.peer_heap_base_p2p + dst_rank);
lijian6's avatar
lijian6 committed
681
                if (peer_base_addr) {
682
                    int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
lijian6's avatar
lijian6 committed
683
                        ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(dushmemi_device_state_d.heap_base)));
lijian6's avatar
lijian6 committed
684
                    st_na_release(req_rptr_actual, 1);
685
686
687
688
                } else {
                    internode::shmem_long_atomic_add(
                        rdma_recv_flag + global_expert_idx, 1, dst_rank);
                }
lishen's avatar
lishen committed
689
#else
690
                #if !defined(ROCM_USE_MULTIQP)
lishen's avatar
lishen committed
691
692
                internode::shmem_long_atomic_add(
                    rdma_recv_flag + global_expert_idx, 1, dst_rank);
693
694
695
696
697
                #else
                internode::shmem_long_atomic_add_dp(
                    rdma_recv_flag + global_expert_idx, 1,
                    (local_expert_idx + 1) * num_ranks + dst_rank, dst_rank);
                #endif
lijian6's avatar
lijian6 committed
698
#endif // defined(FORCE_DUSHMEM_API)
lishen's avatar
lishen committed
699
700
701
702
703
704
            } else {
                st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
            }
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        syncwarp();
705
706
    }

lishen's avatar
lishen committed
707
    // Receiving phase
lishen's avatar
lishen committed
708
LOW_LATENCY_COMBINE_RECV:
lishen's avatar
lishen committed
709
710
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;
711

lishen's avatar
lishen committed
712
713
    // Wait all ranks to arrive and notify PCIe usage
    if (responsible_expert_idx < num_experts) {
lishen's avatar
lishen committed
714
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
        if (sub_warp_id == 0 and lane_id == 0) {
            const auto src_rank = responsible_expert_idx / num_local_experts;
            auto start_time = wall_clock64();
            uint64_t wait_recv_cost = 0;
            while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0  // recv not ready
                   && (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES   // not timeout
            );

            // Mask rank if timeout
            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
                printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
                       rank, responsible_expert_idx % num_local_experts, src_rank);
            }

            if (combine_wait_recv_cost_stats != nullptr) {
                atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
            }
lishen's avatar
lishen committed
732
        }
733
734
735
736
737
#if defined(ROCM_USE_MULTIQP)
        if (sub_warp_id == 2 and lane_id == 0) {
            internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
        }
#endif
738
    }
lishen's avatar
lishen committed
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
    grid_barrier(global_atomic_counter, num_sms);

    // Reduce tokens with FP8 cast
    EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
    EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            // Read top-k indices and weights
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }
754

lishen's avatar
lishen committed
755
756
757
758
            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                // Read from sources
759
760
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
                    (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
lishen's avatar
lishen committed
761
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
lishen's avatar
lishen committed
762
763
764
765
766
767
768
769

                // Reduce
                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }
770

lishen's avatar
lishen committed
771
772
773
774
775
776
777
778
779
            // Write results
            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
780
781
}

lishen's avatar
lishen committed
782
783
784
785
786
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
lishen's avatar
lishen committed
787
             int64_t* combine_wait_recv_cost_stats,
lishen's avatar
lishen committed
788
789
790
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
791
             void* workspace, int num_device_sms, hipStream_t stream,
lishen's avatar
lishen committed
792
             int phases, bool zero_copy) {
lishen's avatar
lishen committed
793
794
795
796
797
    constexpr int kNumMaxTopk = 11;
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
    const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
lishen's avatar
lishen committed
798

lishen's avatar
lishen committed
799
800
801
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms =
        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
lishen's avatar
lishen committed
802
803
804
805
806
807
808

    // Check workspace
    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);

#define COMBINE_LAUNCH_CASE(hidden) { \
lishen's avatar
lishen committed
809
auto combine_func = combine<hidden, kNumMaxTopk>; \
lishen's avatar
lishen committed
810
811
812
813
814
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
              combined_x, \
              rdma_recv_x, rdma_recv_flag, rdma_send_x, \
              x, topk_idx, topk_weights, src_info, layout_range, \
              global_atomic_counter, \
lishen's avatar
lishen committed
815
              combine_wait_recv_cost_stats, \
lishen's avatar
lishen committed
816
817
818
819
820
              next_clean, num_next_clean_int, \
              atomic_clean_flag, \
              num_combined_tokens, hidden, num_topk, \
              num_max_dispatch_tokens_per_rank, \
              num_experts, rank, num_ranks, \
lishen's avatar
lishen committed
821
              num_warp_groups, num_warps_per_group, phases, zero_copy); } break
lishen's avatar
lishen committed
822
823
824
825

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
826
827
}

Chenggang Zhao's avatar
Chenggang Zhao committed
828
829
830
} // namespace internode_ll

} // namespace deep_ep