internode_ll.cu 42.7 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
4
5
6
7
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
lishen's avatar
lishen committed
8
9
10

#include "hip/hip_runtime.h"

11
#include "shmem_wrapper.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
12
13
14
15
16

namespace deep_ep {

namespace internode_ll {

lishen's avatar
lishen committed
17
18
19
20
21
22
23
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
    // Packs two narrow values into one value of twice the width:
    // `x` goes into the lower-addressed half, `y` into the higher-addressed half.
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t result;
    dtype_a_t* halves = reinterpret_cast<dtype_a_t*>(&result);
    halves[0] = x;
    halves[1] = y;
    return result;
}

/**
 * One-shot spin barrier across all blocks of the grid.
 *
 * Each block's thread 0 increments `global_counter[0]` once and then spins until it
 * reads exactly `num_blocks`. The counter is never reset here, so:
 *   - it must be 0 on entry (reset by the caller between launches);
 *   - all `num_blocks` blocks must be resident at the same time, otherwise the spin never ends.
 *
 * Memory ordering: the first `__threadfence()` publishes this block's prior global writes
 * before arriving. The second `__threadfence()`, after the spin, is the acquire side. It makes
 * the other blocks' published writes visible before the block continues. Without it, a
 * relaxed load alone gives no such guarantee.
 */
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    __syncthreads();
    __threadfence();
    if (threadIdx.x == 0) {
        (void)__hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
        while (__hip_atomic_load(global_counter, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) != num_blocks);
        // Acquire side of the barrier (see header comment).
        __threadfence();
    }
    __syncthreads();
}
lishen's avatar
lishen committed
39
40
41
42
43
44
45
46
47
48
49
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    // Quotient rounded up, for non-negative `a` and positive `b`
    // (e.g. ceil_div(7, 2) == 4). `a + b - 1` may overflow for `a` near the type's max.
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}

template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    // Inverse of pack2: splits `packed` into its lower-addressed half (`x`)
    // and its higher-addressed half (`y`).
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    const dtype_a_t* halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
50
51


Chenggang Zhao's avatar
Chenggang Zhao committed
52
/**
 * Zeroes two int64 buffers between two shmem-wide barriers.
 *
 * Expected launch: one block of exactly kNumThreads threads.
 *
 * Only thread 0 calls `internode::shmem_device_barrier_all()`. The block therefore has to
 * wait on `__syncthreads()` around each barrier call:
 *   - after the first barrier, so no thread starts zeroing before the barrier completes;
 *   - before the second barrier, so every thread's zeroing stores are done before thread 0
 *     signals the other ranks.
 * NOTE(review): this assumes shmem_device_barrier_all() fences thread 0's view of memory
 * before signalling; confirm against the shmem wrapper.
 */
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1) {
    // Barrier before cleaning (in case of unfinished chunked EP)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
    __syncthreads();

    // Clean: block-strided loops over both buffers
    const auto thread_id = static_cast<int>(threadIdx.x);
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Barrier after cleaning (make sure low-latency mode works).
    // All threads must have finished zeroing before thread 0 enters the barrier.
    __syncthreads();
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
}

73
74
75
76
77
78
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              hipStream_t stream) {
    // Host launcher: a single 256-thread block, launched non-cooperatively on `stream`,
    // zeroes `num_clean_int_0` elements of `clean_0` and `num_clean_int_1` elements of `clean_1`.
    constexpr int kNumBlocks = 1;
    constexpr int kBlockThreads = 256;

    SETUP_LAUNCH_CONFIG(kNumBlocks, kBlockThreads, stream);
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kBlockThreads>,
                                  clean_0, num_clean_int_0,
                                  clean_1, num_clean_int_1);
}

lishen's avatar
lishen committed
83
84
85
86
/**
 * Issues a non-blocking byte put of `msg_bytes` bytes from local `src_ptr` to `dst_ptr`
 * on rank `dst_rank`. The `_warp` suffix of the wrapper suggests the whole warp
 * should call this together (TODO confirm against the shmem wrapper).
 *
 * With FORCE_DUSHMEM_API or ROCM_DISABLE_MULTIQP defined, the plain put is used and
 * `num_ranks` / `expert_idx` are ignored. Otherwise the `_dp` variant receives an extra
 * index `(expert_idx + 1) * num_ranks + dst_rank`, presumably selecting a per-expert
 * queue pair / channel (verify against the shmem wrapper).
 */
__device__ __forceinline__ void
internode_ll_putmem_nbi(void* dst_ptr, void* src_ptr,
                        int num_ranks, int dst_rank, int expert_idx,
                        int msg_bytes) {
    auto* dst_bytes = reinterpret_cast<signed char*>(dst_ptr);
    auto* src_bytes = reinterpret_cast<signed char*>(src_ptr);
#if defined(FORCE_DUSHMEM_API) || defined(ROCM_DISABLE_MULTIQP)
    internode::shmemx_int8_put_nbi_warp(dst_bytes, src_bytes, msg_bytes, dst_rank);
#else
    const int channel_idx = (expert_idx + 1) * num_ranks + dst_rank;
    internode::shmemx_int8_put_nbi_warp_dp(dst_bytes, src_bytes, msg_bytes, channel_idx, dst_rank);
#endif
}

/**
 * Atomically adds `value` to the `long` at `dest` on rank `dst_rank` via the shmem wrapper.
 *
 * With FORCE_DUSHMEM_API or ROCM_DISABLE_MULTIQP defined, the plain atomic is used and
 * `num_ranks` / `expert_idx` are ignored. Otherwise the `_dp` variant receives the index
 * `(expert_idx + 1) * num_ranks + dst_rank`. This is the same index formula as
 * internode_ll_putmem_nbi, presumably a per-expert queue pair / channel (verify).
 */
__device__ __forceinline__ void
internode_ll_long_atomic_add(long* dest, const long &value,
                             int num_ranks, int dst_rank, int expert_idx) {
#if defined(FORCE_DUSHMEM_API) || defined(ROCM_DISABLE_MULTIQP)
    internode::shmem_long_atomic_add(dest, value, dst_rank);
#else
    const int channel_idx = (expert_idx + 1) * num_ranks + dst_rank;
    internode::shmem_long_atomic_add_dp(dest, value, channel_idx, dst_rank);
#endif
}

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
/**
 * @brief Quantizes kNumElemsPerRead floating-point values (BF16/FP32) to one byte each and
 *        packs them into dst_vec.
 *
 * @tparam kQuantType quantization type:
 *         - 1 = INT8
 *         - 2 = FP8 E4M3
 *         - 3 = FP8 with UE8M0 scales (elements are still converted as E4M3 here)
 *         - 4 = FP8 E5M2
 *         Any other value is rejected at compile time.
 * @tparam kNumElemsPerRead number of source elements packed per call (typically 2, 4 or 8).
 *         Must be even for the FP8 types, which are converted in pairs.
 * @tparam SrcT source element type; must be convertible to float via static_cast
 *         (e.g. float or hip_bfloat16).
 * @tparam DstT destination vector type (e.g. int2 or int4); must hold at least
 *         kNumElemsPerRead bytes.
 * @param src_values source array with at least kNumElemsPerRead elements.
 * @param scale multiplier applied to every element before conversion
 *        (maps FP32 values into the quantized range).
 * @param[out] dst_vec receives kNumElemsPerRead packed bytes starting at its lowest
 *        address; any remaining bytes are left untouched.
 */
template <int kQuantType, int kNumElemsPerRead, typename SrcT, typename DstT>
__forceinline__ __device__ void pack_quantized_values(
    const SrcT* src_values, float scale, DstT& dst_vec) {
    static_assert(kQuantType >= 1 && kQuantType <= 4,
                  "pack_quantized_values: kQuantType must be 1 (INT8), 2/3 (FP8 E4M3) or 4 (FP8 E5M2)");
    static_assert(kNumElemsPerRead > 0 && kNumElemsPerRead <= static_cast<int>(sizeof(DstT)),
                  "pack_quantized_values: dst_vec cannot hold kNumElemsPerRead bytes");

    if constexpr (kQuantType == 1) {
        // INT8 quantization
        auto int8_ptr = reinterpret_cast<int8_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; ++j) {
            // BF16 sources are promoted to float first.
            const float fp32_value_scaled = static_cast<float>(src_values[j]) * scale;
            // Round to nearest with nearbyintf, then saturate to the int8 range.
            // Converting an out-of-range (or NaN) float to int8_t is undefined behavior;
            // the clamp gives saturating semantics like the __HIP_SATFINITE FP8 path.
            // fmaxf maps NaN to -128.
            const float rounded = fminf(fmaxf(nearbyintf(fp32_value_scaled), -128.0f), 127.0f);
            int8_ptr[j] = static_cast<int8_t>(rounded);
        }
    } else {
        // FP8 quantization (E4M3, UE8M0, E5M2): two elements per __hip_fp8x2_storage_t.
        static_assert(kNumElemsPerRead % 2 == 0,
                      "pack_quantized_values: FP8 packing requires an even kNumElemsPerRead");
        auto fp8x2_ptr = reinterpret_cast<__hip_fp8x2_storage_t*>(&dst_vec);
        #pragma unroll
        for (int j = 0; j < kNumElemsPerRead; j += 2) {
            // Convert one pair of elements
            float2 fp32x2 = {static_cast<float>(src_values[j]) * scale,
                             static_cast<float>(src_values[j + 1]) * scale};

            if constexpr (kQuantType == 4) {
                // FP8 E5M2
                fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E5M2);
            } else {
                // FP8 E4M3 or UE8M0: both convert the elements as E4M3
                fp8x2_ptr[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3);
            }
        }
    }
}

/**
 * Low-latency MoE dispatch kernel. `phases` selects the send phase, the receive phase, or both.
 *
 * Send phase:
 *   - Each block walks tokens token_idx = blockIdx.x, blockIdx.x + gridDim.x, ...
 *   - All warps cooperatively copy the token's hidden vector into `rdma_x`. When kQuantType > 0
 *     they quantize it to 8 bits, using per-group or per-token scales.
 *   - Warp w (w < num_topk) sends the packed message to expert topk_idx[token, w]. It uses an RDMA
 *     put when shmem_get_p2p_ptr() returns 0, and direct stores through the P2P pointer otherwise.
 *   - The last warp counts tokens per responsible expert.
 *   - Lane 0 of sub-warp 0 in each warp group then publishes `-num_tokens_sent - 1` into the
 *     receiver's `rdma_recv_count`. The receiver polls for a non-zero value, so 0 means
 *     "not arrived yet".
 *
 * Receive phase:
 *   - Each warp group waits for the count of its (local expert, source rank) pair.
 *   - It reserves a range in `packed_recv_count`.
 *   - Its sub-warps copy the received tokens, source indices and scales into the packed outputs.
 *
 * Launch assumptions implied by the code (TODO confirm against the host launcher):
 *   - blockDim.x == num_warp_groups * num_warps_per_group * kWarpSize <= kMaxNumWarps * kWarpSize.
 *     Every thread must reach the __syncthreads() calls inside the send-phase token loop.
 *   - num_warps_per_group > 1, because sub-warp 1 polls the receive count.
 *   - When both phases run in one launch, grid_barrier() is used. All blocks must then be
 *     co-resident, and `global_atomic_counter` must be 0 on entry.
 */
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(kMaxNumWarps * kWarpSize, 1) void
    dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             int* packed_recv_src_info, int64_t* packed_recv_layout_range,
             int* packed_recv_count,
             int* global_atomic_counter,
             void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
             const void* x, const int64_t* topk_idx,
             int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
             int64_t* next_clean, int num_next_clean_int,
             int num_tokens, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             int num_warp_groups, int num_warps_per_group,
             bool fp8_round_scale, int phases) {
    // Quantization types selected by kQuantType
    enum class QuantType {
        None        = 0,        // no quantization
        Int8        = 1,        // INT8 quantization
        FP8_E4M3    = 2,        // FP8 quantization, __HIP_E4M3
        FP8_UE8M0   = 3,        // FP8 quantization with DeepSeek-V3.1's UE8M0 scales
        FP8_E5M2    = 4         // FP8 quantization, __HIP_E5M2
    };

    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

    // May extract UE8M0 from the scales
    constexpr bool kUseQuant8Bit = kQuantType > 0;
    constexpr bool kUseUE8M0 = kQuantType == static_cast<int>(QuantType::FP8_UE8M0);
    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");

    // 8-bit quantization constants
    constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
    constexpr int kNumScales = kHidden / kNumPerChannels;
    const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

    // Message package: source index (one int4 slot), hidden data, 8-bit scales.
    // The per-token mode (kQuantGroupSize == 0) reserves room for 4 floats of scales.
    // NOTES: currently we have 3 reserved int fields for future use
    using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
    constexpr size_t num_bytes_per_msg = sizeof(int4) +
        (kUseQuant8Bit ? (kHidden + (kQuantGroupSize == 0 ? 4 : kNumScales) * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
    EP_STATIC_ASSERT(num_bytes_per_msg % sizeof(int4) == 0, "Invalid message size");
    constexpr size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);

    // Expert counts
    __shared__ int shared_num_tokens_sent_per_expert[kMaxNumWarps];

    // Sending phase
    // NOTE: no function-scope variable with an initializer may be declared between this
    // goto and its label (jumping past one is ill-formed C++).
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

    // There are 2 kinds of work in this part:
    // 1. All warps cast (optionally quantize) tokens; warps w < num_topk also send them
    // 2. The last warp additionally reads `topk_idx` and counts per-expert information
    if (warp_id < num_warps) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
        constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
        EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = num_warps * kWarpSize;
        constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

            // Overlap top-k index read and source token index write
            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

            // amax of each quantization group, used by the per-token quantization mode
            __shared__ float channel_amaxf[kNumScales];
            if constexpr (kUseQuant8Bit && kQuantGroupSize == 0) {
                if (thread_id < kNumScales) {
                    channel_amaxf[thread_id] = 0.0f;
                }
                __syncthreads();
            }

            // Cast (and, for the per-group mode, quantize) into the send buffer
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

                if constexpr (kUseQuant8Bit) {
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = 0.0f, scale, scale_inv;
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }
                    // Reduce amax and scale
                    EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
                    amax = warp_reduce_max<kNumThreadPerGroup>(amax);
                    const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;

                    if constexpr (kQuantGroupSize == 0) {
                        // Per-token mode: only record the amax of each QUANTIZATION_GROUPSIZE-element
                        // group here; quantization happens below once the token-wide amax is known
                        channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
                    } else {
                        calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
                        if (lane_id % kNumThreadPerGroup == 0)
                            rdma_x_scales[scale_offset] = scale_inv;

                        // Cast into send buffer
                        vec_t int2_value;
                        pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
                        rdma_x_vec[i] = int2_value;
                    }
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            __syncthreads();

            if constexpr (kUseQuant8Bit && kQuantGroupSize == 0) {
                float amax_per_token = 0.0f;
                // Parallel reduction of the per-group amax values into the token amax
                // (accumulated into channel_amaxf[0])
                for (int s = 0; s < kNumScales; s += kWarpSize) {
                    int src_idx = s + lane_id;
                    float tmp_amaxf = 0;
                    if (src_idx < kNumScales) {
                        tmp_amaxf = channel_amaxf[src_idx];
                    }
                    tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
                    channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
                    __syncthreads();
                }
                amax_per_token = channel_amaxf[0];

                // Compute the scale from the token amax
                float scale, scale_inv;
                calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
                if (thread_id == 0) {
                    rdma_x_scales[0] = scale_inv;
                }

                for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                    // Read
                    auto int4_value = __ldg(x_int4 + i);
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);

                    // Cast into send buffer
                    vec_t int2_value;
                    pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
                    rdma_x_vec[i] = int2_value;
                }
                __syncthreads();
            }

            // Issue IBGDA sends
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = shfl_sync(slot_idx, 0);
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;

                // Ask shmem for a directly addressable (P2P) pointer to the remote slot; 0 means none
                uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
                if (p2p_ptr == 0) {  // RDMA
                    internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
                                            num_ranks, dst_rank, dst_expert_local_idx,
                                            num_bytes_per_msg);
                } else {  // Local GPU, or another GPU in the same node
                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
                    UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
                }

                // Increase counter after finishing
                syncwarp();
                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
    }
    if (warp_id == num_warps - 1) {
        // EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
            // The first SM is also responsible for cleaning the next buffer
            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += kWarpSize)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }
        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kMaxNumWarps] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
                expert_count[idx - expert_begin_idx] ++;
        }

        // Warp reduce
        #pragma unroll
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);

        auto dst_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
        // Ask shmem for a directly addressable (P2P) pointer to the remote counter; 0 means none
        uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
        if (p2p_ptr == 0) {  // RDMA
            internode_ll_long_atomic_add(dst_ptr, -num_tokens_sent - 1,
                                         num_ranks, dst_rank, dst_expert_local_idx);
        } else {  // Local GPU, or another GPU in the same node
            st_na_release(reinterpret_cast<int *>(p2p_ptr), -num_tokens_sent - 1);
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    syncwarp();

    // Receiving phase
LOW_LATENCY_DISPATCH_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
    if (phases & LOW_LATENCY_SEND_PHASE) {
        grid_barrier(global_atomic_counter, num_sms);
    }

    // One arrival counter per warp group. 16 (the kMaxNumWarps default) is the max possible
    // number of warps per block on AMD GPUs.
    constexpr int num_sync_large_iteration = kMaxNumWarps;
    __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];

    #pragma unroll
    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

    // Receiving and packing
    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
                                       local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                       src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
                                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
        const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
                                   local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
                                       (kQuantGroupSize == 0 ? 1 : num_aligned_scales);

        // Shared between sub-warps in warp groups
        __shared__ int shared_num_recv_tokens[kMaxNumWarps], shared_recv_token_begin_idx[kMaxNumWarps];

        // Wait tokens to arrive
        // NOTES: using sub-warp 1 to overlap with sub-warp 0
        int num_recv_tokens, recv_token_begin_idx;
        // EP_DEVICE_ASSERT(num_warps_per_group > 1);
        if (sub_warp_id == 1 and lane_id == 0) {
            while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
            num_recv_tokens = -num_recv_tokens - 1;
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
        }

        // Each warp arrives exactly once on its group's counter. No reset is needed because
        // there is no iteration; the counters were zeroed above.
        if (lane_id == 0) {
            // Order sub-warp 1's shared_* stores above before its arrival is observable.
            __threadfence_block();
            (void)__hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
        }
        syncwarp();

        // Wait for every sub-warp of the group, then read the values sub-warp 1 published.
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
        __threadfence_block();
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        // Copy tokens
        EP_STATIC_ASSERT(kNumScales <= 64, "Invalid hidden size");
        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
            // Copy source info
            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            syncwarp();

            // Copy data
            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
            UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

            // Copy scales
            if constexpr (kUseQuant8Bit) {
                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                const auto token_idx = recv_token_begin_idx + i;
                const auto token_stride = num_elems_per_pack;
                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;

                if constexpr (kQuantGroupSize == 0) {
                    // Per-token mode: one scale per token.
                    // NOTE(review): the float scale is stored as scale_t without
                    // extract_required_scale_format; if UE8M0 (scale_t == uint8_t) can be combined
                    // with the per-token mode, this conversion looks wrong. Confirm the host rejects it.
                    if (lane_id == 0) {
                        recv_x_scales[token_idx] = ld_nc_global(src_scales);
                    }
                } else {
                    if (lane_id < kNumScales) {
                        const auto pack_idx = lane_id / num_elems_per_pack;
                        const auto elem_idx = lane_id % num_elems_per_pack;
                        auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
                        recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                    }
                    if (lane_id + kWarpSize < kNumScales) {
                        const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
                        const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
                        auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
                        recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                    }
                }
            }
        }
    }
}

lishen's avatar
lishen committed
540
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
lishen's avatar
lishen committed
541
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
542
              int* packed_recv_count,
543
              int* global_atomic_counter,
lishen's avatar
lishen committed
544
545
546
547
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
lishen's avatar
lishen committed
548
              int num_topk, int num_experts, int rank, int num_ranks,
549
              int quant_type, int quant_group_size, bool fp8_round_scale,
lishen's avatar
lishen committed
550
              void* workspace, int num_device_sms,
lishen's avatar
lishen committed
551
              hipStream_t stream, int phases) {
552
    constexpr int kMaxNumWarps = 16;
553
    constexpr int kNumMaxTopK = 11;
lishen's avatar
lishen committed
554
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
555
    const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
556
557
558
559
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

    const auto num_warps = num_warp_groups * num_warps_per_group;
lishen's avatar
lishen committed
560
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
561
562
563
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Workspace checks
lishen's avatar
lishen committed
564
    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
565
566
567
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

568
569
570
571
572
573
    // 限制groupsize的大小
    EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);

    /*量化类型枚举
    0 -> None          不量化,保持原始精度
    1 -> Int8          使用 INT8 对称量化
lishen's avatar
lishen committed
574
    2 -> FP8_E4M3      使用 FP8 E4M3 格式 (__HIP_E4M3)
575
    3 -> FP8_UE8M0     使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
lishen's avatar
lishen committed
576
    4 -> FP8_E5M2      使用 FP8 E5M2 格式 (__HIP_E5M2)
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
    */

#define DISPATCH_LAUNCH_CASE(hidden)                                                \
  {                                                                                 \
    auto dispatch_func = dispatch<hidden, 0, 0, kMaxNumWarps>;                      \
    if (quant_group_size == 0) {                                                    \
        switch (quant_type) {                                                       \
            case 1: dispatch_func = dispatch<hidden, 1, 0, kMaxNumWarps>; break;    \
            case 2: dispatch_func = dispatch<hidden, 2, 0, kMaxNumWarps>; break;    \
            case 3: dispatch_func = dispatch<hidden, 3, 0, kMaxNumWarps>; break;    \
            case 4: dispatch_func = dispatch<hidden, 4, 0, kMaxNumWarps>; break;    \
        }                                                                           \
    } else {                                                                        \
        switch (quant_type) {                                                       \
            case 1: dispatch_func = dispatch<hidden, 1, 128, kMaxNumWarps>; break;  \
            case 2: dispatch_func = dispatch<hidden, 2, 128, kMaxNumWarps>; break;  \
            case 3: dispatch_func = dispatch<hidden, 3, 128, kMaxNumWarps>; break;  \
            case 4: dispatch_func = dispatch<hidden, 4, 128, kMaxNumWarps>; break;  \
        }                                                                           \
    }                                                                               \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func,                              \
        packed_recv_x, packed_recv_x_scales,                                        \
        packed_recv_src_info, packed_recv_layout_range, packed_recv_count,          \
        global_atomic_counter,                                                      \
        rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx,                          \
        atomic_counter_per_expert, atomic_finish_counter_per_expert,                \
        next_clean, num_next_clean_int,                                             \
        num_tokens, num_max_dispatch_tokens_per_rank,                               \
        num_topk, num_experts, rank, num_ranks,                                     \
        num_warp_groups, num_warps_per_group, fp8_round_scale, phases);             \
  }                                                                                 \
  break
609
610
611
612

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
613
614
}

615
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
lishen's avatar
lishen committed
616
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
617
618
619
620
621
combine(void* combined_x,
        void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* global_atomic_counter,
lishen's avatar
lishen committed
622
        int64_t* combine_wait_recv_cost_stats,
lishen's avatar
lishen committed
623
624
625
626
627
        int64_t* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
628
        int num_warp_groups, int num_warps_per_group,
lishen's avatar
lishen committed
629
630
631
632
633
634
635
        int phases, bool zero_copy) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
lishen's avatar
lishen committed
636
637
638
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
lishen's avatar
lishen committed
639
640
641
642
643
644

    // Data type staffs
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    // Message package
645
    EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
lishen's avatar
lishen committed
646
    constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
lishen's avatar
lishen committed
647
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
lishen's avatar
lishen committed
648

649
    // 初始化用于细粒度warp间同步的计数器数组
lishen's avatar
lishen committed
650
651
652
653
654
655
656
657
    __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
    if (threadIdx.x==0){
        #pragma unroll
        for (int i = 0; i < kMaxNumWarps; ++i) {
            sync_large_warp_counters[i] = 0;
        }
    }
    __syncthreads();
658

lishen's avatar
lishen committed
659
660
661
    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;
Chenggang Zhao's avatar
Chenggang Zhao committed
662

lishen's avatar
lishen committed
663
664
665
666
667
    // Clean up next buffer
    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
            next_clean[i] = 0;
668

lishen's avatar
lishen committed
669
670
671
672
673
        // Notify before executing `int_p`
        syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }
674

lishen's avatar
lishen committed
675
676
677
678
679
680
681
    // Issue IBGDA sends
    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
lishen's avatar
lishen committed
682
                             local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
lishen's avatar
lishen committed
683
684
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
lishen's avatar
lishen committed
685
                                     local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
lishen's avatar
lishen committed
686
687
688
689
690
691

        // Unpack layout
        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        // Issue IBGDA send
lishen's avatar
lishen committed
692
        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
lishen's avatar
lishen committed
693
694
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
lishen's avatar
lishen committed
695
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
lishen's avatar
lishen committed
696
697

            // Copy directly to local rank, or copy to buffer and issue RDMA
698
            const auto src_idx = __ldg(local_src_info + token_idx);
lishen's avatar
lishen committed
699
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
lishen's avatar
lishen committed
700
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
lishen's avatar
lishen committed
701
702
703
            
            uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
            if (p2p_ptr == 0) {  // RDMA
lishen's avatar
lishen committed
704
705
706
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
707

lishen's avatar
lishen committed
708
709
710
711
712
713
714
715
                internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
                    num_ranks, dst_rank, local_expert_idx,
                    hidden * sizeof(hip_bfloat16));
            } else { //  本地 GPU 和 同一计算节点的 其他 GPU 地址
                // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
                const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
lishen's avatar
lishen committed
716
            }
lishen's avatar
lishen committed
717
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
718

lishen's avatar
lishen committed
719
        // Put finishing flag
lishen's avatar
lishen committed
720
        // EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
721
        if (lane_id == 0){
lishen's avatar
lishen committed
722
            volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
lishen's avatar
lishen committed
723
724
        }
        syncwarp();
lishen's avatar
lishen committed
725
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
lishen's avatar
lishen committed
726

lishen's avatar
lishen committed
727
728
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
lishen's avatar
lishen committed
729
730
731
732
733
734
735
736

            auto dst_ptr = rdma_recv_flag + global_expert_idx;
            // 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
            uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
            if (p2p_ptr == 0) {  // RDMA
                internode_ll_long_atomic_add(dst_ptr, 1, num_ranks, dst_rank, local_expert_idx);
            } else { //  本地 GPU 和 同一计算节点的 其他 GPU 地址
                st_na_release(reinterpret_cast<int *>(p2p_ptr), 1);
lishen's avatar
lishen committed
737
            }
lishen's avatar
lishen committed
738

lishen's avatar
lishen committed
739
740
741
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        syncwarp();
742
743
    }

lishen's avatar
lishen committed
744
    // Receiving phase
lishen's avatar
lishen committed
745
LOW_LATENCY_COMBINE_RECV:
lishen's avatar
lishen committed
746
747
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;
748

lishen's avatar
lishen committed
749
750
    // Wait all ranks to arrive and notify PCIe usage
    if (responsible_expert_idx < num_experts) {
lishen's avatar
lishen committed
751
        // EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
        if (sub_warp_id == 0 and lane_id == 0) {
            const auto src_rank = responsible_expert_idx / num_local_experts;
            auto start_time = wall_clock64();
            uint64_t wait_recv_cost = 0;
            while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0  // recv not ready
                   && (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES   // not timeout
            );

            // Mask rank if timeout
            if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
                printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
                       rank, responsible_expert_idx % num_local_experts, src_rank);
            }

            if (combine_wait_recv_cost_stats != nullptr) {
                atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
            }
lishen's avatar
lishen committed
769
        }
770
    }
lishen's avatar
lishen committed
771
772
773
    grid_barrier(global_atomic_counter, num_sms);

    // Reduce tokens with FP8 cast
lishen's avatar
lishen committed
774
    // EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
lishen's avatar
lishen committed
775
776
777
778
779
780
781
782
783
784
785
    EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            // Read top-k indices and weights
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }
786

lishen's avatar
lishen committed
787
788
789
790
            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                // Read from sources
791
792
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
                    (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
lishen's avatar
lishen committed
793
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
lishen's avatar
lishen committed
794
795
796
797
798
799
800
801

                // Reduce
                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }
802

lishen's avatar
lishen committed
803
804
805
806
807
808
809
810
811
            // Write results
            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
812
813
}

lishen's avatar
lishen committed
814
815
816
817
818
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
lishen's avatar
lishen committed
819
             int64_t* combine_wait_recv_cost_stats,
lishen's avatar
lishen committed
820
821
822
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
823
             void* workspace, int num_device_sms, hipStream_t stream,
lishen's avatar
lishen committed
824
             int phases, bool zero_copy) {
825
    constexpr int kMaxNumWarps = 16;
lishen's avatar
lishen committed
826
827
    constexpr int kNumMaxTopk = 11;
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
828
    const int num_warps_per_group = kMaxNumWarps / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
lishen's avatar
lishen committed
829
830
    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
lishen's avatar
lishen committed
831

lishen's avatar
lishen committed
832
833
834
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms =
        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
lishen's avatar
lishen committed
835
836
837
838
839
840

    // Check workspace
    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);

841
842
843
844
845
846
847
848
849
850
851
852
853
854
#define COMBINE_LAUNCH_CASE(hidden)                                            \
  {                                                                            \
    auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>;            \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func,                          \
        combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x,                  \
        x, topk_idx, topk_weights, src_info, layout_range,                     \
        global_atomic_counter, combine_wait_recv_cost_stats,                   \
        next_clean, num_next_clean_int,                                        \
        atomic_clean_flag, num_combined_tokens, hidden,                        \
        num_topk, num_max_dispatch_tokens_per_rank,                            \
        num_experts, rank, num_ranks,                                          \
        num_warp_groups, num_warps_per_group, phases, zero_copy);              \
  }                                                                            \
  break
lishen's avatar
lishen committed
855
856
857
858

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
859
860
}

Chenggang Zhao's avatar
Chenggang Zhao committed
861
862
863
} // namespace internode_ll

} // namespace deep_ep