internode_ll.cu 33.9 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
4
5
6
7
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
lishen's avatar
lishen committed
8
9
10

#include "hip/hip_runtime.h"

11
12
13
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX

14
#include "shmem_wrapper.cuh"
Chenggang Zhao's avatar
Chenggang Zhao committed
15
16
17
18
19

namespace deep_ep {

namespace internode_ll {

lishen's avatar
lishen committed
20
21
22
23
24
25
26
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
    // Build one dtype_b_t word out of two dtype_a_t halves: `x` goes to the lower-addressed
    // half (index 0) and `y` to the upper-addressed half (index 1).
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t result;
    dtype_a_t* halves = reinterpret_cast<dtype_a_t*>(&result);
    halves[0] = x;
    halves[1] = y;
    return result;
}

// Software grid-wide barrier for kernels launched without cooperative-launch support.
//
// Every block increments `global_counter[0]` exactly once; thread 0 of each block then spins
// until the counter equals `num_blocks`, and the whole block is released by `__syncthreads()`.
//
// Preconditions:
//  - must be reached by every thread of every block (it contains `__syncthreads()`);
//  - all `num_blocks` blocks must be resident at the same time, otherwise the spin never ends;
//  - `global_counter[0]` must be 0 on entry: the wait is an exact equality test and nothing in
//    this function resets the counter. NOTE(review): confirm the host clears it before each launch.
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    // Every wave fences its own prior global writes *before* the block-level barrier, so by the
    // time thread 0 announces arrival, all of this block's writes are ordered at device scope.
    __threadfence();
    __syncthreads();
    if (threadIdx.x == 0) {
        atomicAdd(&global_counter[0], 1);
        // Acquire at agent scope: writes made by other blocks before their arrival must be
        // visible to this block once the wait ends (a relaxed load gives no such guarantee).
        while (__hip_atomic_load(global_counter, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) != num_blocks);
    }
    // Release the rest of the block only after thread 0 has observed every arrival.
    __syncthreads();
}
lishen's avatar
lishen committed
43
44
45
46
47
48
49
50
51
52
53
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    // Integer division rounded up: bias the numerator by (b - 1) before truncating.
    // `auto` keeps the usual integer promotion of the sum, so narrow types are not truncated
    // before the division. NOTE(review): intended for a >= 0 and b > 0; the sum can overflow
    // near the type's maximum.
    const auto biased = a + b - 1;
    return biased / b;
}

template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    // Inverse of `pack2`: split one dtype_b_t word into its two dtype_a_t halves,
    // lower-addressed half (index 0) into `x`, upper-addressed half (index 1) into `y`.
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    const dtype_a_t* halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
54
55


Chenggang Zhao's avatar
Chenggang Zhao committed
56
// Zeroes two low-latency buffers between two symmetric-heap barriers.
//
// Expected launch: a single block of `kNumThreads` threads (see the host wrapper); each thread
// clears indices `thread_id, thread_id + kNumThreads, ...` of both buffers.
//
// Only thread 0 calls `internode::shmem_device_barrier_all()`, so block-level barriers are needed
// around it: otherwise the other threads would start zeroing before the leading barrier returns,
// and thread 0 could enter the trailing barrier while other threads are still zeroing.
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1) {
    // Barrier before cleaning (in case of unfinished chunked EP)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
    // Hold every thread until thread 0 has returned from the barrier above.
    __syncthreads();

    // Clean
    const auto thread_id = static_cast<int>(threadIdx.x);
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Make every thread's zeroing stores ordered at device scope and wait for the whole block
    // before thread 0 announces completion to the other ranks.
    __threadfence();
    __syncthreads();

    // Barrier after cleaning (make sure low-latency mode works)
    if (threadIdx.x == 0)
        internode::shmem_device_barrier_all();
}

77
78
79
80
81
82
// Host launcher: zero `num_clean_int_0` int64 entries of `clean_0` and `num_clean_int_1`
// int64 entries of `clean_1` on `stream`, using one block of 256 threads.
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              hipStream_t stream) {
    constexpr int kNumThreads = 256;
    SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kNumThreads>,
                                  clean_0, num_clean_int_0,
                                  clean_1, num_clean_int_1);
}

lishen's avatar
lishen committed
87
// Low-latency dispatch kernel.
//
// Launch shape (set by the host wrapper): `gridDim.x` blocks, each with
// `num_warp_groups * num_warps_per_group` warps; warp group `g` of block `b` owns expert
// `b * num_warp_groups + g` (`responsible_expert_idx`).
//
// Send phase (phases & LOW_LATENCY_SEND_PHASE):
//   - block `sm_id` handles tokens `sm_id, sm_id + num_sms, ...`, packing each one into `rdma_x`
//     as a message of `num_bytes_per_msg` bytes: an int4 header whose first int is the source
//     token index, then the hidden vector (FP8 when kUseFP8, otherwise BF16), then the
//     per-channel float scales (FP8 only);
//   - warp `k < num_topk` sends the token to expert `topk_idx[token * num_topk + k]`, into a slot
//     taken from `atomic_counter_per_expert`, with a non-blocking shmem put (remote rank) or a
//     warp copy (local rank);
//   - the last warp counts the tokens for the experts this block is responsible for; the count is
//     published in `rdma_recv_count` encoded as `-count - 1`, so 0 means "not arrived yet".
// Receive phase (phases & LOW_LATENCY_RECV_PHASE):
//   - each warp group waits for its (expert, source rank) count, reserves a range in
//     `packed_recv_count`, records it in `packed_recv_layout_range`, and copies source indices,
//     hidden data and (FP8) scales into the packed output buffers.
// When both phases run in one launch, `grid_barrier` separates them, which requires all blocks
// to be co-resident.
template <bool kUseFP8, bool kUseUE8M0, int kHidden>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         int* packed_recv_src_info, int64_t* packed_recv_layout_range,
         int* packed_recv_count,
         int* global_atomic_counter,
         void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
         const void* x, const int64_t* topk_idx,
         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
         int64_t* next_clean, int num_next_clean_int,
         int num_tokens, int num_max_dispatch_tokens_per_rank,
         int num_topk, int num_experts, int rank, int num_ranks,
         int num_warp_groups, int num_warps_per_group,
         bool round_scale, int phases) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

    // May extract UE8M0 from the scales
    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
    EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");

    // FP8 stuff
    constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
    const int num_scales = kHidden / kNumPerChannels;
    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

    // Message package: hidden data, FP8 scales, index at source
    // NOTES: currently we have 3 reserved int fields for future use
    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

    // Expert counts
    constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

    // Sending phase
    // NOTE: no function-scope declarations may appear between this `goto` and its label.
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

#if !defined(ROCM_DISABLE_CTX)
    __shared__ internode::shmem_ctx_t ctx;
    internode::shmem_wg_ctx_create(&ctx);
#endif

    // Work split in this part:
    // 1. All warps of the block cast (FP8) / copy each token into the send buffer;
    //    warps with `warp_id < num_topk` then send the token to one top-k expert each.
    // 2. Afterwards, the last warp reads `topk_idx` and counts per-expert information.
    if (warp_id < num_warps) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
        EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
        // Every warp of the block runs the cast loop below (the enclosing branch admits all of
        // them and the loop ends in `__syncthreads()`), so the stride must be the whole block.
        // A stride of `(num_warps - 1) * kWarpSize` made the last warp redo warp 0's elements.
        const auto num_threads = num_warps * kWarpSize;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

            // Overlap top-k index read and source token index write
            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

            // FP8 cast
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

                if (kUseFP8) {
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }
                    // Reduce amax and scale: 16 consecutive lanes cover one quantization channel
                    EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
                    amax = warp_reduce_max<16>(amax);
                    calculate_fp8_scales(amax, scale, scale_inv, round_scale);
                    if (lane_id % 16 == 0)
                        rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv;

                    // Cast into send buffer
                    vec_t int2_value;
                    auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                        fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
                    }
                    rdma_x_vec[i] = int2_value;
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            // The whole message must be written before any warp starts sending it
            __syncthreads();

            // Issue sends (one warp per top-k destination)
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = shfl_sync(slot_idx, 0);
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;
                if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX)
                    internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else
                    internode::shmemx_int8_put_nbi_warp(
#endif
                        reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
                        num_bytes_per_msg, dst_rank);
                    // NOTE(review): no fence/quiet is issued between this non-blocking put and the
                    // count atomic sent later (possibly from another warp); confirm rocSHMEM orders
                    // them, otherwise the receiver could see the count before the data.
                } else {
                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                    UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
                }

                // Increase counter after finishing
                syncwarp();
                lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
            }
        }
    }
    if (warp_id == num_warps - 1) {
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
            // The first SM is also responsible for cleaning the next buffer
            // (no QP check is performed in this port)
            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += kWarpSize)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }
        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kNumMaxWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
                expert_count[idx - expert_begin_idx] ++;
        }

        // Warp reduce
        #pragma unroll
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait local sends issued and send expert counts
        // (the counter reaches 2 * FINISHED_SUM_TAG once the clean notification, the count, and
        // one increment per sent token have all been added)
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
#if !defined(ROCM_DISABLE_CTX)
            internode::shmem_ctx_long_atomic_add(ctx,
#else
            internode::shmem_long_atomic_add(
#endif
                rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
        } else {
            st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    syncwarp();

#if !defined(ROCM_DISABLE_CTX)
    internode::shmem_wg_ctx_destroy(&ctx);
#endif

    // Receiving phase
LOW_LATENCY_DISPATCH_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;

    // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
    if (phases & LOW_LATENCY_SEND_PHASE){
        grid_barrier(global_atomic_counter, num_sms);
    }

    // 16 is the max possible number of warps in AMD GPUs
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    constexpr int num_sync_large_iteration = kMaxNumWarps;
    __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];

#pragma unroll
    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

    // Receiving and packing
    if (responsible_expert_idx < num_experts) {
        const auto src_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
                                       local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                       src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
        const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
                                 local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
        const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
        const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t));
        const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
                                   local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;

        // Shared between sub-warps in warp groups
        __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];

        // Wait tokens to arrive
        // NOTES: using sub-warp 1 to overlap with sub-warp 0
        int num_recv_tokens, recv_token_begin_idx;
        EP_DEVICE_ASSERT(num_warps_per_group > 1);

        if (sub_warp_id == 1 and lane_id == 0) {
            // The sender publishes `-count - 1`, so 0 means "count not arrived yet"
            while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
            num_recv_tokens = -num_recv_tokens - 1;
            recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
            shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
            shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
            recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
        }

        // Warp-group barrier on a shared counter; no reset needed because it is used only once
        if (lane_id == 0){
            atomicAdd(const_cast<int*>(&sync_large_warp_counters[warp_group_id]), 1);
        }
        syncwarp();

        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
        num_recv_tokens = shared_num_recv_tokens[warp_group_id];
        recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

        // Copy tokens
        EP_DEVICE_ASSERT(num_scales <= 64);
        for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
            // Copy source info
            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
            if (lane_id == 0)
                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
            syncwarp();

            // Copy data
            // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
            const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
            UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

            // Copy scales
            if (kUseFP8) {
                const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                const auto token_idx = recv_token_begin_idx + i;
                const auto token_stride = num_elems_per_pack;
                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;

                if (lane_id < num_scales) {
                    const auto pack_idx = lane_id / num_elems_per_pack;
                    const auto elem_idx = lane_id % num_elems_per_pack;
                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
                if (lane_id + kWarpSize < num_scales) {
                    const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
                    const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
                    auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                }
            }
        }
    }
}

lishen's avatar
lishen committed
414
// Host launcher for the low-latency dispatch kernel.
//
// Splits experts over blocks: `num_warp_groups = ceil(num_experts / num_device_sms)` experts per
// block, `16 / num_warp_groups` warps per expert, and `ceil(num_experts / num_warp_groups)` blocks.
// `workspace` holds two int arrays of `num_experts` entries (slot counters and finish counters).
// The kernel template is selected from `hidden` (via SWITCH_HIDDEN) and the `use_fp8` /
// `use_ue8m0` flags.
void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* global_atomic_counter,
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
              bool use_fp8, bool round_scale, bool use_ue8m0,
              void* workspace, int num_device_sms,
              hipStream_t stream, int phases) {
    constexpr int kNumMaxTopK = 11;

    // Validate everything that is used as a divisor below (and in the kernel) *before* dividing:
    // `ceil_div` divides by `num_device_sms`, the warp split divides by `num_warp_groups`, and the
    // kernel divides by `num_ranks` and by `num_experts / num_ranks`.
    EP_HOST_ASSERT(num_experts > 0 and num_device_sms > 0);
    EP_HOST_ASSERT(num_ranks > 0 and num_experts % num_ranks == 0);

    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
    EP_HOST_ASSERT(num_warp_groups > 0);
    const int num_warps_per_group = 16 / num_warp_groups;
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    // One warp per top-k slot plus the counting warp
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
    // The receiving phase uses sub-warp 1 of each group and asserts this on the device; check it
    // here so a bad configuration fails on the host instead of trapping the GPU.
    if (phases & LOW_LATENCY_RECV_PHASE) {
        EP_HOST_ASSERT(num_warps_per_group > 1);
    }

    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Workspace checks
    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

    // Picks the <kUseFP8, kUseUE8M0, kHidden> instantiation and launches it (no comments inside
    // the macro body: a `//` comment would swallow the line continuation).
#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, hidden>; \
if (use_fp8 and not use_ue8m0)                       \
    dispatch_func = dispatch<true, false, hidden>;   \
if (use_fp8 and use_ue8m0)                           \
    dispatch_func = dispatch<true, true, hidden>;    \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
              packed_recv_x, packed_recv_x_scales, \
              packed_recv_src_info, packed_recv_layout_range, \
              packed_recv_count, \
              global_atomic_counter, \
              rdma_recv_x, rdma_recv_count, rdma_x, \
              x, topk_idx, \
              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
              next_clean, num_next_clean_int, \
              num_tokens, num_max_dispatch_tokens_per_rank, \
              num_topk, num_experts, rank, num_ranks, \
              num_warp_groups, num_warps_per_group, round_scale, phases); } break

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lishen's avatar
lishen committed
465
466
template <int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
lishen's avatar
lishen committed
467
468
469
470
471
472
473
474
475
476
combine(void* combined_x,
        void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* global_atomic_counter,
        int64_t* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
lishen's avatar
lishen committed
477
        int num_warp_groups, int num_warps_per_group,
lishen's avatar
lishen committed
478
479
480
481
482
483
484
        int phases, bool zero_copy) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
lishen's avatar
lishen committed
485
486
487
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
lishen's avatar
lishen committed
488
489
490
491
492
493

    // Data type staffs
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    // Message package
lishen's avatar
lishen committed
494
    EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
lijian6's avatar
lijian6 committed
495
    constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
lishen's avatar
lishen committed
496
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
lishen's avatar
lishen committed
497

lishen's avatar
lishen committed
498
    // 16 is the max possible number of warps in AMD GPUs
lishen's avatar
lishen committed
499
500
501
502
503
504
505
506
507
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
    if (threadIdx.x==0){
        #pragma unroll
        for (int i = 0; i < kMaxNumWarps; ++i) {
            sync_large_warp_counters[i] = 0;
        }
    }
    __syncthreads();
508

lishen's avatar
lishen committed
509
510
511
    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;
Chenggang Zhao's avatar
Chenggang Zhao committed
512

lishen's avatar
lishen committed
513
514
515
516
517
#if !defined(ROCM_DISABLE_CTX)
    __shared__ internode::shmem_ctx_t ctx;
    internode::shmem_wg_ctx_create(&ctx);
#endif

lishen's avatar
lishen committed
518
519
520
521
522
    // Clean up next buffer
    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
            next_clean[i] = 0;
523

lishen's avatar
lishen committed
524
525
526
527
528
        // Notify before executing `int_p`
        syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }
529

lishen's avatar
lishen committed
530
531
532
533
534
535
536
    // Issue IBGDA sends
    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
lishen's avatar
lishen committed
537
                             local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
lishen's avatar
lishen committed
538
539
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
lishen's avatar
lishen committed
540
                                     local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
lishen's avatar
lishen committed
541
542
543
544
545
546

        // Unpack layout
        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        // Issue IBGDA send
lishen's avatar
lishen committed
547
        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) {
lishen's avatar
lishen committed
548
549
550
551
552
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);

            // Copy directly to local rank, or copy to buffer and issue RDMA
553
            const auto src_idx = __ldg(local_src_info + token_idx);
lishen's avatar
lishen committed
554
555
556
557
558
559
560
561
562
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
            if (dst_rank == rank) {
                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
            } else {
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
563

lishen's avatar
lishen committed
564
565
566
                    //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if !defined(ROCM_DISABLE_CTX)
                internode::shmem_ctx_schar_put_nbi_warp(ctx,
lishen's avatar
lishen committed
567
#else
lishen's avatar
lishen committed
568
                internode::shmemx_int8_put_nbi_warp(
lishen's avatar
lishen committed
569
#endif
lishen's avatar
lishen committed
570
571
572
                    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
                    hidden * sizeof(hip_bfloat16), dst_rank);
            }
lishen's avatar
lishen committed
573
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
574

lishen's avatar
lishen committed
575
        // Put finishing flag
lishen's avatar
lishen committed
576
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
577
578
579
580
581
        if (lane_id == 0){
            // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
        }
        syncwarp();
lishen's avatar
lishen committed
582
        while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
lishen's avatar
lishen committed
583

lishen's avatar
lishen committed
584
585
586
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
            if (dst_rank != rank) {
lishen's avatar
lishen committed
587
588
#if !defined(ROCM_DISABLE_CTX)
                internode::shmem_ctx_long_atomic_add(ctx,
lishen's avatar
lishen committed
589
#else
lishen's avatar
lishen committed
590
                internode::shmem_long_atomic_add(
lishen's avatar
lishen committed
591
#endif
lishen's avatar
lishen committed
592
                    rdma_recv_flag + global_expert_idx, 1, dst_rank);
lishen's avatar
lishen committed
593
594
595
596
597
598
            } else {
                st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
            }
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        syncwarp();
lishen's avatar
lishen committed
599
600
601
602
603
604
605
606

        if (num_ranks > 8){
#if !defined(ROCM_DISABLE_CTX)
            internode::shmem_ctx_quiet(ctx);
#else
            internode::shmem_fence();
#endif
        }
607
608
    }

lishen's avatar
lishen committed
609
610
611
612
#if !defined(ROCM_DISABLE_CTX)
    internode::shmem_wg_ctx_destroy(&ctx);
#endif

lishen's avatar
lishen committed
613
    // Receiving phase
lishen's avatar
lishen committed
614
LOW_LATENCY_COMBINE_RECV:
lishen's avatar
lishen committed
615
616
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;
617

lishen's avatar
lishen committed
618
619
620
621
622
623
624
625
    //     if (num_ranks > 8){
    // #if !defined(ROCM_DISABLE_CTX)
    //         internode::shmem_ctx_quiet(ctx);
    // #else
    //         internode::shmem_fence();
    // #endif
    //     }

lishen's avatar
lishen committed
626
627
    // Wait all ranks to arrive and notify PCIe usage
    if (responsible_expert_idx < num_experts) {
lishen's avatar
lishen committed
628
        EP_DEVICE_ASSERT(num_warps_per_group > 1);
lishen's avatar
lishen committed
629
630
631
        if (sub_warp_id == 0 and lane_id == 0){
            while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
        }
632
    }
lishen's avatar
lishen committed
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
    grid_barrier(global_atomic_counter, num_sms);

    // Reduce tokens with FP8 cast
    EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
    EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            // Read top-k indices and weights
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }
648

lishen's avatar
lishen committed
649
650
651
652
653
654
655
656
657
658
659
660
661
662
            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                // Read from sources
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);

                // Reduce
                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }
663

lishen's avatar
lishen committed
664
665
666
667
668
669
670
671
672
            // Write results
            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
673
674
}

lishen's avatar
lishen committed
675
676
677
678
679
680
681
682
// Host-side launcher for the low-latency combine kernel (`combine<hidden, kNumMaxTopk>`).
//
// Launch geometry:
//   - Every block holds `num_warp_groups = ceil(num_experts / num_device_sms)` warp groups of
//     `num_warps_per_group = 16 / num_warp_groups` warps, i.e. at most 16 warps per block.
//   - The grid has `num_sms` blocks, chosen so that `num_sms * num_warp_groups >= num_experts`
//     and `num_sms * ceil(num_combined_tokens / num_device_sms) >= num_combined_tokens`;
//     neither term exceeds `num_device_sms`.
//   - The kernel is launched non-cooperatively but synchronizes all blocks with
//     `grid_barrier`; NOTE(review): this presumably relies on all `num_sms` blocks being
//     co-resident on the device — confirm for the target hardware.
//
// Parameters (only roles visible from the kernel body are stated as fact):
//   combined_x              output; the receive phase writes one reduced row per combined token.
//   rdma_recv_x             receive slots, written by peers' puts and read by the receive phase.
//   rdma_recv_flag          per-expert arrival flags; the send phase increments them and the
//                           receive phase spins until they are non-zero.
//   rdma_send_x             staging buffer for remote puts (the staging copy is skipped when
//                           `zero_copy` is true).
//   x                       local tokens to send back (TODO confirm layout).
//   topk_idx, topk_weights  row-major [num_combined_tokens, num_topk]; entries with a negative
//                           index are skipped during the reduction.
//   src_info, layout_range  per-token source slot index and packed (count, offset) layout,
//                           presumably produced by the dispatch step — confirm.
//   global_atomic_counter   counter used by the kernel's grid-wide barrier.
//   next_clean, num_next_clean_int  forwarded to the kernel (presumably a buffer to reset for
//                           the next launch — confirm).
//   workspace               scratch memory; its first int is used as `atomic_clean_flag`.
//   num_device_sms          upper bound used to size the grid; must be > 0.
//   phases                  bitmask; the receive phase is skipped unless LOW_LATENCY_RECV_PHASE
//                           is set.
//   zero_copy               when true, the send phase does not copy `x` into `rdma_send_x`
//                           before issuing the put.
//
// Any violated precondition trips EP_HOST_ASSERT before the kernel is launched.
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             void* workspace, int num_device_sms, hipStream_t stream,
             int phases, bool zero_copy) {
    // Compile-time top-k bound; sizes the kernel's per-thread top-k register arrays.
    constexpr int kNumMaxTopk = 11;
    // Total warps per block, shared among all warp groups.
    constexpr int kNumMaxWarpsPerBlock = 16;

    // Validate the divisors first: `ceil_div` divides by `num_device_sms`, and the warp-group
    // split below divides by `num_warp_groups`, which would be 0 for `num_experts == 0`.
    EP_HOST_ASSERT(num_device_sms > 0 and num_experts > 0 and num_combined_tokens >= 0);

    const int num_warp_groups = ceil_div(num_experts, num_device_sms);
    const int num_warps_per_group = kNumMaxWarpsPerBlock / num_warp_groups;
    // The kernel requires more than one warp per group: it asserts `num_warps_per_group > 1`
    // on device, and sub-warp 1 of each group is the one that posts the finishing flag.
    // Checking here reports the problem on the host instead of as a device-side failure.
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 1);
    const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);

    const auto num_warps = num_warp_groups * num_warps_per_group;
    // Enough blocks to give every expert a warp group and to cover all combined tokens.
    const auto num_sms =
        max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));

    // Check workspace: its first int serves as the kernel's `atomic_clean_flag`.
    auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);

    // Instantiates and launches the kernel for one supported hidden size (expanded by
    // SWITCH_HIDDEN). Keep comments out of the macro body: a `//` comment on a continued
    // line would swallow the following line.
#define COMBINE_LAUNCH_CASE(hidden) { \
    auto combine_func = combine<hidden, kNumMaxTopk>; \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
                  combined_x, \
                  rdma_recv_x, rdma_recv_flag, rdma_send_x, \
                  x, topk_idx, topk_weights, src_info, layout_range, \
                  global_atomic_counter, \
                  next_clean, num_next_clean_int, \
                  atomic_clean_flag, \
                  num_combined_tokens, hidden, num_topk, \
                  num_max_dispatch_tokens_per_rank, \
                  num_experts, rank, num_ranks, \
                  num_warp_groups, num_warps_per_group, phases, zero_copy); } break

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}

Chenggang Zhao's avatar
Chenggang Zhao committed
719
720
721
} // namespace internode_ll

} // namespace deep_ep