internode_ll.cu 33.2 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
4
5
6
7
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
lishen's avatar
lishen committed
8
9
10

#include "hip/hip_runtime.h"

11
12
13
// Low-latency mode has issues with rocSHMEM contexts (CTX), so contexts are disabled here.
#define ROCM_DISABLE_CTX

14
#include "shmem_wrapper.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
15
16
17
18
19

namespace deep_ep {

namespace internode_ll {

lishen's avatar
lishen committed
20
21
22
23
24
25
26
// Packs two `dtype_a_t` values into one `dtype_b_t`, which must be exactly
// twice as large. `x` is stored in the first half of the result's storage and
// `y` in the second half (in memory order).
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t result;
    // View the wide value as an array of two narrow halves
    dtype_a_t* halves = reinterpret_cast<dtype_a_t*>(&result);
    halves[0] = x;
    halves[1] = y;
    return result;
}

// Software barrier across all blocks of the grid, built on a counter in global memory.
//
// `global_counter`: counter shared by every block. This function only ever
//                   increments it and waits for it to equal `num_blocks`, so it
//                   must hold 0 when the first block arrives and is NOT reset here.
// `num_blocks`:     number of blocks that will call this function.
//
// Must be reached by every thread of every participating block (it contains
// `__syncthreads()`). NOTE(review): the spin wait presumably requires all blocks
// to be resident on the device at the same time; confirm against the launch config.
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    // Everything this block did before the barrier must be complete and
    // visible device-wide before the block announces its arrival
    __syncthreads();
    __threadfence();

    if (threadIdx.x == 0) {
        // Arrive. The returned value is kept in a volatile local as in the
        // original code. NOTE(review): presumably to force the returning form
        // of the atomic; confirm before removing.
        volatile int ret;
        ret = atomicAdd(&global_counter[0], 1);

        // Wait for all blocks. The load is an ACQUIRE at agent scope so that it
        // pairs with the `__threadfence()` + atomic on the arrive side: reads
        // issued after the barrier cannot be satisfied with data older than the
        // other blocks' pre-barrier writes (a relaxed load gives no such ordering).
        while (__hip_atomic_load(global_counter, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) != num_blocks);
    }

    // Release the rest of the block once thread 0 has observed the full count
    __syncthreads();
}
lishen's avatar
lishen committed
43
44
45
46
47
48
49
50
51
52
53
// Returns `a / b` rounded up, computed as `(a + b - 1) / b`.
// Intended for non-negative `a` and positive `b`.
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    // `auto` keeps the (possibly promoted) type of the sum, exactly as in a
    // single-expression form
    const auto rounded_up_numerator = a + b - 1;
    return rounded_up_numerator / b;
}

// Splits one `dtype_b_t` into two `dtype_a_t` values; `dtype_b_t` must be
// exactly twice as large as `dtype_a_t`. `x` receives the first half of
// `packed`'s storage and `y` the second half (in memory order).
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    // View the wide value as an array of two narrow halves
    const dtype_a_t* halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
54
55


Chenggang Zhao's avatar
Chenggang Zhao committed
56
// Zeroes two low-latency buffers, bracketed by SHMEM barriers.
//
// `clean_0` / `clean_1`:                 buffers to zero.
// `num_clean_int_0` / `num_clean_int_1`: number of `int64_t` elements to zero in each.
//
// The loops stride by `kNumThreads` and index by `threadIdx.x` only, so the
// kernel must be launched with `kNumThreads` threads per block and is written
// for a single block (the host launcher in this file uses a grid of 1).
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1) {
    // Barrier before cleaning (in case of unfinished chunked EP)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
    // Only thread 0 takes part in the barrier above; hold the rest of the
    // block back so that no thread starts zeroing before it has completed
    __syncthreads();

    // Clean
    auto thread_id = static_cast<int>(threadIdx.x);
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Every thread's zero stores must be issued and visible device-wide before
    // thread 0 reports this rank as done to the other ranks
    __threadfence();
    __syncthreads();

    // Barrier after cleaning (make sure low-latency mode works fine)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
}

77
78
79
80
81
82
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              hipStream_t stream) {
    constexpr int kNumThreads = 256;

    SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
lishen's avatar
lishen committed
83
84
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kNumThreads>,
                  clean_0, num_clean_int_0, clean_1, num_clean_int_1);
Chenggang Zhao's avatar
Chenggang Zhao committed
85
86
}

lishen's avatar
lishen committed
87
template <bool kUseFP8, bool kUseUE8M0, int kHidden>
lishen's avatar
lishen committed
88
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
89
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
lishen's avatar
lishen committed
90
91
92
93
94
95
96
97
98
         int* packed_recv_src_info, int64_t* packed_recv_layout_range,
         int* packed_recv_count,
         int* global_atomic_counter,
         void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
         const void* x, const int64_t* topk_idx,
         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
         int64_t* next_clean, int num_next_clean_int,
         int num_tokens, int num_max_dispatch_tokens_per_rank,
         int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
99
100
         int num_warp_groups, int num_warps_per_group, 
         bool round_scale, int phases) {
101
#if !defined(ROCM_DISABLE_CTX)
102
103
    __shared__ internode::shmem_ctx_t ctx;
    internode::shmem_wg_ctx_create(&ctx);
104
105
106
107
108
109
110
111
112
113
114
115
#endif

    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

lishen's avatar
lishen committed
116
117
118
119
120
    // May extract UE8M0 from the scales
    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");

121
122
123
124
125
126
    // FP8 staffs
    constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
    const int num_scales = kHidden / kNumPerChannels;
    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

lishen's avatar
lishen committed
127
    // Message package: hidden data, FP8 scales, index at source
128
    // NOTES: currently we have 3 reserved int fields for future use
lishen's avatar
lishen committed
129
    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
130
131
132
133
    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

lishen's avatar
lishen committed
134
    // 16 is the max possible number of warps in AMD GPUs 
135
136
137
138
139
140
141
142
143
144
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    constexpr int num_sync_large_iteration = kMaxNumWarps ;
    __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];

    #pragma unroll
    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

lishen's avatar
lishen committed
145
146
147
148
149
    // Expert counts
    constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

    // Sending phase
150
151
152
153
154
155
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

    // There are 2 kinds of warps in this part:
    // 1. The first-kind warps for FP8 cast and sending top-k tokens
    // 2. The last warp for reading `topk_idx` and count for per-expert information
lishen's avatar
lishen committed
156
157
158
    if (warp_id < num_warps) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
159
160
161
162
163
        EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * kWarpSize;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
lishen's avatar
lishen committed
164
165
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
166
167
168
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

lishen's avatar
lishen committed
169
            // Overlap top-k index read and source token index write
170
171
172
173
174
175
176
177
178
            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

            // FP8 cast
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

lishen's avatar
lishen committed
179
                if (kUseFP8) {
180
181
182
183
184
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
                    #pragma unroll
lishen's avatar
lishen committed
185
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
186
187
188
189
190
191
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }
                    // Reduce amax and scale
                    EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
                    amax = warp_reduce_max<16>(amax);
lishen's avatar
lishen committed
192
                    calculate_fp8_scales(amax, scale, scale_inv, round_scale);
193
                    if (lane_id % 16 == 0)
lishen's avatar
lishen committed
194
                        rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv;
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214

                    // Cast into send buffer
                    vec_t int2_value;
                    auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                        fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
                    }
                    rdma_x_vec[i] = int2_value;
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            __syncthreads();
            // Issue IBGDA sends
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = shfl_sync(slot_idx, 0);
lishen's avatar
lishen committed
215
216
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
217
218
219
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
lishen's avatar
lishen committed
220
221
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;
222
                if (dst_rank != rank) {
223
                    internode::shmemx_int8_put_nbi_warp(reinterpret_cast<signed char*>(dst_ptr), 
lishen's avatar
lishen committed
224
                        reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
225
                    internode::shmem_fence();
226
227
228
229
                } else {
                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
lishen's avatar
lishen committed
230
                    UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
231
                }
lishen's avatar
lishen committed
232

233
234
235
236
237
                // Increase counter after finishing
                syncwarp();
                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
lishen's avatar
lishen committed
238
239
    }
    if (warp_id == num_warps - 1) {
240
241
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
lishen's avatar
lishen committed
242
            // The first SM is also responsible for checking QPs
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
            // The first SM is also responsible for cleaning the next buffer
            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += kWarpSize)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }
        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kNumMaxWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
lishen's avatar
lishen committed
264
                expert_count[idx - expert_begin_idx] ++;
265
266
267
268
        }

        // Warp reduce
        #pragma unroll
lishen's avatar
lishen committed
269
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
270
271
272
273
274
275
276
277
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }

lishen's avatar
lishen committed
278
    //revert sync_large_warp_counters to 0 for next sync
279
280
281
282
283
284
285
286
287
288
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
lishen's avatar
lishen committed
289
        if (dst_rank != rank) {
290
            internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
lishen's avatar
lishen committed
291
292
        } else {
            st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
293
294
295
296
297
298
299
300
301
302
303
304
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    syncwarp();

lishen's avatar
lishen committed
305
306
    // Receiving phase
LOW_LATENCY_DISPATCH_RECV:
307
308
309
310
311
312
313
314
315
316
317
318
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
    if (phases & LOW_LATENCY_SEND_PHASE){
        grid_barrier(global_atomic_counter, num_sms);
    }

    // Receiving and packing
    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
lishen's avatar
lishen committed
319
320
321
322
323
        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
324
325
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
lishen's avatar
lishen committed
326
327
328
        const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
            local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
329
330
331
332
333
334

        // Shared between sub-warps in warp groups
        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];

        // Wait tokens to arrive
        // NOTES: using sub-warp 1 to overlap with sub-warp 0
lishen's avatar
lishen committed
335
336
        int num_recv_tokens, recv_token_begin_idx;
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
337
338

        if (sub_warp_id == 1 and lane_id == 0) {
lishen's avatar
lishen committed
339
            while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
340
            num_recv_tokens = -num_recv_tokens - 1;
lishen's avatar
lishen committed
341
342
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
343
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
lishen's avatar
lishen committed
344
            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
345
346
347
348
        }

        // no needs to reset because there is no iteration
        if (lane_id == 0){
lishen's avatar
lishen committed
349
            // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
350
351
352
353
            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
        }
        syncwarp();

lishen's avatar
lishen committed
354
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        // Copy tokens
        EP_DEVICE_ASSERT(num_scales <= 64);
        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
            // Copy source info
            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            syncwarp();

            // Copy data
            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
lishen's avatar
lishen committed
371
            UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
372
373

            // Copy scales
lishen's avatar
lishen committed
374
            if (kUseFP8) {
375
                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
lishen's avatar
lishen committed
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                const auto token_idx = recv_token_begin_idx + i;
                const auto token_stride = num_elems_per_pack;
                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;

                if (lane_id < num_scales) {
                    const auto pack_idx = lane_id / num_elems_per_pack;
                    const auto elem_idx = lane_id % num_elems_per_pack;
                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
                if (lane_id + kWarpSize < num_scales) {
                    const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
                    const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
393
394
395
396
397
            }
        }
    }

#if !defined(ROCM_DISABLE_CTX)
398
    internode::shmem_wg_ctx_destroy(&ctx);
399
#endif
Chenggang Zhao's avatar
Chenggang Zhao committed
400
401
}

lishen's avatar
lishen committed
402
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
lishen's avatar
lishen committed
403
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
404
              int* packed_recv_count,
405
              int* global_atomic_counter,
lishen's avatar
lishen committed
406
407
408
409
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
lishen's avatar
lishen committed
410
411
412
413
              int num_topk, int num_experts, int rank, int num_ranks, 
              bool use_fp8, bool round_scale, bool use_ue8m0,
              void* workspace, int num_device_sms, 
              hipStream_t stream, int phases) {
414
    constexpr int kNumMaxTopK = 11;
lishen's avatar
lishen committed
415
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
lishen's avatar
lishen committed
416
    const int num_warps_per_group = 16 / num_warp_groups;
417
418
419
420
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

    const auto num_warps = num_warp_groups * num_warps_per_group;
lishen's avatar
lishen committed
421
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
422
423
424
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Workspace checks
lishen's avatar
lishen committed
425
    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
426
427
428
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

lishen's avatar
lishen committed
429
#define DISPATCH_LAUNCH_CASE(hidden) { \
lishen's avatar
lishen committed
430
431
432
433
434
auto dispatch_func = dispatch<false, false, hidden>; \
if (use_fp8 and not use_ue8m0)                       \
    dispatch_func = dispatch<true, false, hidden>;   \
if (use_fp8 and use_ue8m0)                           \
    dispatch_func = dispatch<true, true, hidden>;    \
lishen's avatar
lishen committed
435
436
437
438
439
440
441
442
443
444
445
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
              packed_recv_x, packed_recv_x_scales, \
              packed_recv_src_info, packed_recv_layout_range, \
              packed_recv_count, \
              global_atomic_counter, \
              rdma_recv_x, rdma_recv_count, rdma_x, \
              x, topk_idx, \
              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
              next_clean, num_next_clean_int, \
              num_tokens, num_max_dispatch_tokens_per_rank, \
              num_topk, num_experts, rank, num_ranks, \
lishen's avatar
lishen committed
446
              num_warp_groups, num_warps_per_group, round_scale, phases); } break
447
448
449
450

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
451
452
}

lishen's avatar
lishen committed
453
454
template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
455
456
457
458
459
460
461
462
463
464
combine(void* combined_x,
        void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* global_atomic_counter,
        int64_t* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
465
        int num_warp_groups, int num_warps_per_group,
lishen's avatar
lishen committed
466
467
        int phases, bool zero_copy) {

468
#if !defined(ROCM_DISABLE_CTX)
469
470
    __shared__ internode::shmem_ctx_t ctx;
    internode::shmem_wg_ctx_create(&ctx);
471
#endif
lishen's avatar
lishen committed
472
473
474
475
476
477
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
lishen's avatar
lishen committed
478
479
480
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
lishen's avatar
lishen committed
481
482
483
484
485
486

    // Data type staffs
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    // Message package
lishen's avatar
lishen committed
487
    EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
lijian6's avatar
lijian6 committed
488
    constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
lishen's avatar
lishen committed
489
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
lishen's avatar
lishen committed
490

lishen's avatar
lishen committed
491
492
493
494
495
496
497
498
499
500
    // 16 is the max possible number of warps in AMD GPUs 
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
    if (threadIdx.x==0){
        #pragma unroll
        for (int i = 0; i < kMaxNumWarps; ++i) {
            sync_large_warp_counters[i] = 0;
        }
    }
    __syncthreads();
501

lishen's avatar
lishen committed
502
503
504
    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;
Chenggang Zhao's avatar
Chenggang Zhao committed
505

lishen's avatar
lishen committed
506
507
508
509
510
    // Clean up next buffer
    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
            next_clean[i] = 0;
511

lishen's avatar
lishen committed
512
513
514
515
516
        // Notify before executing `int_p`
        syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }
517

lishen's avatar
lishen committed
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
    // Issue IBGDA sends
    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;

        // Unpack layout
        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        // Issue IBGDA send
lishen's avatar
lishen committed
535
        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
lishen's avatar
lishen committed
536
537
538
539
540
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);

            // Copy directly to local rank, or copy to buffer and issue RDMA
541
            const auto src_idx = __ldg(local_src_info + token_idx);
lishen's avatar
lishen committed
542
543
544
545
546
547
548
549
550
551
552
553
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
            if (dst_rank == rank) {
                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
            } else {
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
                
                //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
554
                    internode::shmemx_int8_put_nbi_warp(
lishen's avatar
lishen committed
555
#else
556
                    internode::shmem_ctx_schar_put_nbi_warp(ctx,
lishen's avatar
lishen committed
557
558
#endif
                    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
559

lishen's avatar
lishen committed
560
#if defined(ROCM_DISABLE_CTX)
561
                    internode::shmem_fence();
lishen's avatar
lishen committed
562
#else
563
                    internode::shmem_ctx_quiet(ctx);
lishen's avatar
lishen committed
564
565
566
#endif
                }
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
567

lishen's avatar
lishen committed
568
        // Put finishing flag
lishen's avatar
lishen committed
569
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
570
571
572
573
574
        if (lane_id == 0){
            // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
        }
        syncwarp();
lishen's avatar
lishen committed
575
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
lishen's avatar
lishen committed
576
577
578
579
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
            if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX)
580
                internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
lishen's avatar
lishen committed
581
#else
582
                internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
lishen's avatar
lishen committed
583
584
585
586
587
588
589
#endif
            } else {
                st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
            }
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        syncwarp();
590
591
    }

lishen's avatar
lishen committed
592
593
594
595
    // Receiving phase
    LOW_LATENCY_COMBINE_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;
596

lishen's avatar
lishen committed
597
598
    // Wait all ranks to arrive and notify PCIe usage
    if (responsible_expert_idx < num_experts) {
lishen's avatar
lishen committed
599
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
600
601
602
        if (sub_warp_id == 0 and lane_id == 0){
            while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
        }
603
    }
lishen's avatar
lishen committed
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
    grid_barrier(global_atomic_counter, num_sms);

    // Reduce tokens with FP8 cast
    EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
    EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            // Read top-k indices and weights
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }
619

lishen's avatar
lishen committed
620
621
622
623
624
625
626
627
628
629
630
631
632
633
            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                // Read from sources
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);

                // Reduce
                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }
634

lishen's avatar
lishen committed
635
636
637
638
639
640
641
642
643
644
            // Write results
            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
#if !defined(ROCM_DISABLE_CTX)
645
    internode::shmem_wg_ctx_destroy(&ctx);
lishen's avatar
lishen committed
646
#endif
647
648
}

lishen's avatar
lishen committed
649
650
651
652
653
654
655
656
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
657
             void* workspace, int num_device_sms, hipStream_t stream,
lishen's avatar
lishen committed
658
             int phases, bool zero_copy) {
lishen's avatar
lishen committed
659
660
661
662
663
    constexpr int kNumMaxTopk = 11;
    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
    const int num_warps_per_group = 16 / num_warp_groups; // num_warps_per_group>1, "Requires more than one warp per group"
    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
lishen's avatar
lishen committed
664

lishen's avatar
lishen committed
665
666
667
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms =
        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
lishen's avatar
lishen committed
668
669
670
671
672
673
674

    // Check workspace
    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);

#define COMBINE_LAUNCH_CASE(hidden) { \
lishen's avatar
lishen committed
675
auto combine_func = combine<hidden, kNumMaxTopk>; \
lishen's avatar
lishen committed
676
677
678
679
680
681
682
683
684
685
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
              combined_x, \
              rdma_recv_x, rdma_recv_flag, rdma_send_x, \
              x, topk_idx, topk_weights, src_info, layout_range, \
              global_atomic_counter, \
              next_clean, num_next_clean_int, \
              atomic_clean_flag, \
              num_combined_tokens, hidden, num_topk, \
              num_max_dispatch_tokens_per_rank, \
              num_experts, rank, num_ranks, \
lishen's avatar
lishen committed
686
              num_warp_groups, num_warps_per_group, phases, zero_copy); } break
lishen's avatar
lishen committed
687
688
689
690

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
691
692
}

Chenggang Zhao's avatar
Chenggang Zhao committed
693
694
695
} // namespace internode_ll

} // namespace deep_ep