internode_ll.cu 31.6 KB
Newer Older
Chenggang Zhao's avatar
Chenggang Zhao committed
1
2
3
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
4
5
6
7
#include "buffer.cuh"
#include "utils.cuh"
// #include <cooperative_groups.h>
#include <iostream>
lishen's avatar
lishen committed
8
9
10

#include "hip/hip_runtime.h"

11
12
13
14
// low latency+RocSHMEM has issue with CTX.
#define ROCM_DISABLE_CTX

#include <rocshmem/rocshmem.hpp>
Chenggang Zhao's avatar
Chenggang Zhao committed
15

lishen's avatar
lishen committed
16
using namespace rocshmem;
Chenggang Zhao's avatar
Chenggang Zhao committed
17
18
19
20
namespace deep_ep {

namespace internode_ll {

lishen's avatar
lishen committed
21
22
23
24
25
26
27
// Packs two `dtype_a_t` values into one `dtype_b_t` that is exactly twice as large.
// `x` is stored in the lower-addressed half of the result, `y` in the upper half.
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ dtype_b_t pack2(const dtype_a_t& x, const dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    dtype_b_t combined;
    dtype_a_t* const halves = reinterpret_cast<dtype_a_t*>(&combined);
    halves[0] = x;
    halves[1] = y;
    return combined;
}

// Software grid-wide barrier: every block's thread 0 increments `global_counter[0]` once and
// then spins until the counter equals `num_blocks`.
//
// Preconditions (not checked here):
//   * all `num_blocks` blocks must be resident on the device at the same time, otherwise the
//     spinning blocks wait forever for a block that is never scheduled;
//   * `global_counter[0]` must be 0 on entry. This function never resets it, so reusing the
//     same counter for another barrier requires the caller to clear it first.
__device__ void grid_barrier(int* global_counter, int num_blocks) {
    // Order all of this block's prior global writes before the arrival below.
    __syncthreads();
    __threadfence();
    if (threadIdx.x == 0) {
        atomicAdd(global_counter, 1);
        // Acquire (not relaxed): once every arrival is observed, global data written by the
        // other blocks before they arrived must be visible here rather than read from a
        // stale cache.
        while (__hip_atomic_load(global_counter, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) != num_blocks);
    }
    // Release the rest of the block only after thread 0 has observed all arrivals.
    __syncthreads();
}
lishen's avatar
lishen committed
44
45
46
47
48
49
50
51
52
53
54
// Divides `a` by `b`, rounding the quotient up (intended for non-negative operands).
template <typename dtype_t>
__host__ __device__ dtype_t ceil_div(dtype_t a, dtype_t b) {
    const auto padded = a + b - 1;
    return padded / b;
}

// Splits a `dtype_b_t` into its two `dtype_a_t` halves: the lower-addressed half goes to `x`
// and the upper half to `y` (the inverse of `pack2`).
template <typename dtype_a_t, typename dtype_b_t>
__device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, dtype_a_t& y) {
    EP_STATIC_ASSERT(sizeof(dtype_a_t) * 2 == sizeof(dtype_b_t), "Invalid dtypes");
    const dtype_a_t* const halves = reinterpret_cast<const dtype_a_t*>(&packed);
    x = halves[0];
    y = halves[1];
}
55
56


Chenggang Zhao's avatar
Chenggang Zhao committed
57
// Zeroes `clean_0[0..num_clean_int_0)` and `clean_1[0..num_clean_int_1)`, bracketed by
// rocSHMEM barriers across all PEs.
//
// Expected launch: one block of exactly `kNumThreads` threads (the clearing loops stride by
// `kNumThreads`, not by `blockDim.x`).
template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
__global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                                         int64_t* clean_1, int num_clean_int_1) {
    // Barrier before cleaning (in case of unfinished chunked EP).
    // Only thread 0 joins the PE-wide barrier; the block barrier keeps every other thread from
    // starting to clean until that PE-wide barrier has completed.
    if (threadIdx.x == 0)
        rocshmem::rocshmem_barrier_all();
    __syncthreads();

    // Clean
    const auto thread_id = static_cast<int>(threadIdx.x);
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_0; i += kNumThreads)
        clean_0[i] = 0;
    #pragma unroll
    for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
        clean_1[i] = 0;

    // Barrier after cleaning (make sure low-latency mode works): publish every thread's stores
    // and wait for the whole block before thread 0 signals the other PEs.
    __threadfence_system();
    __syncthreads();
    if (threadIdx.x == 0)
        rocshmem::rocshmem_barrier_all();
}

78
79
80
81
82
83
// Host launcher: zeroes both buffers on `stream` using a single block of 256 threads.
void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
                              int64_t* clean_1, int num_clean_int_1,
                              hipStream_t stream) {
    constexpr int kBlockThreads = 256;
    constexpr int kNumBlocks = 1;

    SETUP_LAUNCH_CONFIG(kNumBlocks, kBlockThreads, stream);
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, clean_low_latency_buffer<kBlockThreads>,
                                  clean_0, num_clean_int_0,
                                  clean_1, num_clean_int_1);
}

lishen's avatar
lishen committed
88
89
90
91
92
93
94
95
96
97
98
99
100
// Low-latency dispatch kernel.
//
// Layout: `gridDim.x` blocks of `num_warp_groups * num_warps_per_group` warps. Warp group
// `g` of block `b` is responsible for expert `b * num_warp_groups + g`, both when sending that
// expert's token count and when receiving that expert's tokens.
//
// `phases` selects the send half (LOW_LATENCY_SEND_PHASE), the receive half
// (LOW_LATENCY_RECV_PHASE), or both. When both run in one launch, `grid_barrier` on
// `global_atomic_counter` separates them.
//
// Per-token message: one int4 header (its first int is the source token index, the other
// three ints are reserved), then the hidden data (FP8 or BF16), then, in FP8 mode only,
// `kHidden / kNumPerChannels` float scales.
template <bool kUseFP8, int kHidden>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         int* packed_recv_src_info, int64_t* packed_recv_layout_range,
         int* packed_recv_count,
         int* global_atomic_counter,
         void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
         const void* x, const int64_t* topk_idx,
         int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
         int64_t* next_clean, int num_next_clean_int,
         int num_tokens, int num_max_dispatch_tokens_per_rank,
         int num_topk, int num_experts, int rank, int num_ranks,
         int num_warp_groups, int num_warps_per_group, int phases) {
#if !defined(ROCM_DISABLE_CTX)
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif

    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / num_warps_per_group;
    const auto sub_warp_id = warp_id % num_warps_per_group;
    const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;

    // FP8 stuff
    constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
    const int num_scales = kHidden / kNumPerChannels;
    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
    const size_t hidden_int4 = hidden_bytes / sizeof(int4);

    // Message package: hidden data, FP8 scales, index at source
    // NOTES: currently we have 3 reserved int fields for future use
    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
    EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

    // Per-warp-group arrival counters, used as an intra-group barrier in the receive phase.
    // 1024 / kWarpSize is the maximum number of warps in a block (16 with 64-wide wavefronts).
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    constexpr int num_sync_large_iteration = kMaxNumWarps;
    __shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];

    #pragma unroll
    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
        sync_large_warp_counters[i] = 0;
    }
    __syncthreads();

    // Expert counts
    constexpr int kNumMaxWarpGroups = 1024 / kWarpSize;
    __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_DISPATCH_RECV;

    // Every warp of the block (with the host launcher's block size, `warp_id < num_warps` always
    // holds) builds the send buffer for this block's tokens (FP8 cast or BF16 copy). Warps
    // `0 .. num_topk - 1` then each send the token to one of its top-k experts. The last warp
    // additionally reads `topk_idx` to count tokens per expert (see below).
    if (warp_id < num_warps) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
        // Number of consecutive lanes whose elements share one FP8 scale
        constexpr int kNumLanesPerScale = kNumPerChannels / kNumElemsPerRead;
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
        EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
        // All warps take part in the cast loop, so stride by the whole block. Striding by one
        // warp less made the last warp redo (and re-write) the first warp's elements.
        const auto num_threads = num_warps * kWarpSize;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);

            // Overlap top-k index read and source token index write
            auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
            if (thread_id == 0)
                *rdma_x_src_idx = token_idx;

            // FP8 cast
            #pragma unroll
            for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                // Read
                auto int4_value = __ldg(x_int4 + i);

                if (kUseFP8) {
                    // Calculate local amax
                    auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
                    float fp32_values[kNumElemsPerRead];
                    float amax = kFP8Margin, scale, scale_inv;
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
                        fp32_values[j] = static_cast<float>(bf16_values[j]);
                        amax = fmaxf(amax, fabsf(fp32_values[j]));
                    }

                    // Reduce amax over the lanes sharing one scale, then compute the scale
                    EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
                    amax = warp_reduce_max<kNumLanesPerScale>(amax);
                    calculate_fp8_scales</*round_scale*/false>(amax, scale, scale_inv);
                    if (lane_id % kNumLanesPerScale == 0)
                        rdma_x_scales[i * kNumElemsPerRead / kNumPerChannels] = scale_inv;

                    // Cast into send buffer
                    vec_t int2_value;
                    auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
                    #pragma unroll
                    for (int j = 0; j < kNumElemsPerRead; j += 2) {
                        float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
                        fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
                    }
                    rdma_x_vec[i] = int2_value;
                } else {
                    // Reinterpret-cast is for C++14 compatibility
                    rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
                }
            }
            __syncthreads();

            // Issue sends: rocSHMEM put for remote ranks, direct warp copy for the local rank
            if (dst_expert_idx >= 0) {
                int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
                slot_idx = shfl_sync(slot_idx, 0);
                const auto dst_rank = dst_expert_idx / num_local_experts;
                const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                                     dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                                     slot_idx * num_bytes_per_msg;
                if (dst_rank != rank) {
                    rocshmem::rocshmem_schar_put_nbi_wave(reinterpret_cast<signed char*>(dst_ptr),
                        reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
                    rocshmem::rocshmem_fence();
                } else {
                    // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                    const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                    const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                    UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
                }

                // Increase counter after finishing
                syncwarp();
                if (lane_id == 0)
                    atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1);
            }
        }
    }
    if (warp_id == num_warps - 1) {
        EP_DEVICE_ASSERT(num_sms > 1);
        if (sm_id == 0) {
            // The first SM is also responsible for cleaning the next buffer
            #pragma unroll
            for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                next_clean[i] = 0;

            // Notify before executing `int_p`
            syncwarp();
            #pragma unroll
            for (int i = lane_id; i < num_experts; i += kWarpSize)
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
        }

        // This SM should be responsible for some destination experts, read `topk_idx` for them
        int expert_count[kNumMaxWarpGroups] = {0};
        const auto expert_begin_idx = sm_id * num_warp_groups;
        const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

        // Per lane count
        #pragma unroll 8
        for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
            auto idx = static_cast<int>(__ldg(topk_idx + i));
            if (idx >= expert_begin_idx and idx < expert_end_idx)
                expert_count[idx - expert_begin_idx] ++;
        }

        // Warp reduce
        #pragma unroll
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
            auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
            if (lane_id == 0) {
                shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
            }
        }
    }

    // Make `shared_num_tokens_sent_per_expert` visible to every warp of the block
    __syncthreads();

    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
            rocshmem::rocshmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
        } else {
            st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

        // Clean `packed_recv_count`
        if (dst_rank == 0)
            packed_recv_count[dst_expert_local_idx] = 0;
    }
    syncwarp();

    // Receiving phase
LOW_LATENCY_DISPATCH_RECV:
    // Run the receive half only when requested. Falling through (instead of returning early)
    // keeps the context teardown at the end of the kernel reachable.
    if ((phases & LOW_LATENCY_RECV_PHASE) != 0) {
        // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
        if (phases & LOW_LATENCY_SEND_PHASE) {
            grid_barrier(global_atomic_counter, num_sms);
        }

        // Receiving and packing
        if (responsible_expert_idx < num_experts) {
            const auto src_rank = responsible_expert_idx / num_local_experts;
            const auto local_expert_idx = responsible_expert_idx % num_local_experts;
            const auto rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
                    local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                    src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
            const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
                    local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
            const auto recv_x_scales = packed_recv_x_scales + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_scales;
            const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
            const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;

            // Shared between sub-warps in warp groups
            __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];

            // Wait tokens to arrive
            // NOTES: using sub-warp 1 to overlap with sub-warp 0
            int num_recv_tokens, recv_token_begin_idx;
            EP_DEVICE_ASSERT(num_warps_per_group > 1);
            if (sub_warp_id == 1 and lane_id == 0) {
                while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
                num_recv_tokens = -num_recv_tokens - 1;
                recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
                shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
                shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
                recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
            }

            // Intra-group barrier: lane 0 of each warp arrives once, then the whole warp spins until
            // every warp of the group has arrived. The counters were zeroed at kernel start and are
            // used only once per launch, so no reset is needed.
            if (lane_id == 0) {
                // Release: publish the shared-memory writes above before arriving
                __threadfence_block();
                atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
            }
            syncwarp();
            while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
            // Acquire: do not read the shared token counts ahead of the barrier
            __threadfence_block();

            num_recv_tokens = shared_num_recv_tokens[warp_group_id];
            recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];

            // Copy tokens
            // Each lane copies at most two scales (lane_id and lane_id + kWarpSize)
            EP_DEVICE_ASSERT(num_scales <= 2 * kWarpSize);
            for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
                // Copy source info
                const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
                if (lane_id == 0)
                    recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
                syncwarp();

                // Copy data
                // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
                const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
                const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

                // Copy scales
                if (kUseFP8) {
                    const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                    const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
                    const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
                    auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
                    auto scale_1 = (lane_id + kWarpSize) < num_scales ? ld_nc_global(src_scales + lane_id + kWarpSize) : 0;
                    if (lane_id < num_scales)
                        dst_scales[lane_id * scale_stride] = scale_0;
                    if ((lane_id + kWarpSize) < num_scales)
                        dst_scales[(lane_id + kWarpSize) * scale_stride] = scale_1;
                }
            }
        }
    }

#if !defined(ROCM_DISABLE_CTX)
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
}

lishen's avatar
lishen committed
384
385
// Host launcher for the low-latency dispatch kernel.
//
// Experts are split into groups of `num_warp_groups = ceil(num_experts / 80)` per block (80 is a
// hard-coded device SM count). Each block holds at most 16 warps, split evenly between those
// groups. `workspace` provides two `num_experts`-long int arrays: per-expert slot counters
// followed by per-expert finish counters.
void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
              int* packed_recv_count,
              int* global_atomic_counter,
              void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int64_t* next_clean, int num_next_clean_int,
              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
              void* workspace, hipStream_t stream, int phases) {
    constexpr int kNumMaxTopK = 11;
    const int num_warp_groups = ceil_div(num_experts, /*num_device_sms*/80);
    const int num_warps_per_group = 16 / num_warp_groups;
    EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
    EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);

    const auto num_warps = num_warp_groups * num_warps_per_group;
    const auto num_sms = ceil_div(num_experts, num_warp_groups);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopK);

    // Mirror the kernel's device-side assertions so a bad configuration fails here with a
    // clear host error instead of trapping inside the kernel:
    //   * the send phase requires more than one block;
    //   * the receive phase requires at least two warps per group (sub-warp 1 polls the
    //     incoming token count).
    EP_HOST_ASSERT((phases & LOW_LATENCY_SEND_PHASE) == 0 or num_sms > 1);
    EP_HOST_ASSERT((phases & LOW_LATENCY_RECV_PHASE) == 0 or num_warps_per_group > 1);

    // Workspace checks
    auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
    auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
    EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);

#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = use_fp8 ? dispatch<true, hidden> : \
                               dispatch<false, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
              packed_recv_x, packed_recv_x_scales, \
              packed_recv_src_info, packed_recv_layout_range, \
              packed_recv_count, \
              global_atomic_counter, \
              rdma_recv_x, rdma_recv_count, rdma_x, \
              x, topk_idx, \
              atomic_counter_per_expert, atomic_finish_counter_per_expert, \
              next_clean, num_next_clean_int, \
              num_tokens, num_max_dispatch_tokens_per_rank, \
              num_topk, num_experts, rank, num_ranks, \
              num_warp_groups, num_warps_per_group, phases); } break

    SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
    SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

lishen's avatar
lishen committed
430
431
432
433
434
435
436
437
438
439
440
441
442
443
template <int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, int kNumMaxTopk>
__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * kWarpSize, 1) void
combine(void* combined_x,
        void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
        const void* x, const int64_t* topk_idx, const float* topk_weights,
        const int* src_info, const int64_t* layout_range,
        int* global_atomic_counter,
        int64_t* next_clean, int num_next_clean_int,
        int* atomic_clean_flag,
        int num_combined_tokens, int hidden, int num_topk,
        int num_max_dispatch_tokens_per_rank,
        int num_experts, int rank, int num_ranks,
        int phases, bool zero_copy) {

444
445
446
447
#if !defined(ROCM_DISABLE_CTX)
    __shared__ rocshmem::rocshmem_ctx_t ctx;
    rocshmem::rocshmem_wg_ctx_create(0, &ctx);
#endif
lishen's avatar
lishen committed
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto num_threads = static_cast<int>(blockDim.x);
    const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
    const auto num_local_experts = num_experts / num_ranks;
    const auto warp_group_id = warp_id / kNumWarpsPerGroup;
    const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
    const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;

    // Data type staffs
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
    const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;

    // Message package
    // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot)
    constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
    EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
    __syncthreads();
    // 16 is the max possible number of warps in AMD GPUs 
    constexpr int kMaxNumWarps = 1024 / kWarpSize;
    __shared__ volatile int sync_large_warp_counters[kMaxNumWarps];
    if (threadIdx.x==0){
        #pragma unroll
        for (int i = 0; i < kMaxNumWarps; ++i) {
            sync_large_warp_counters[i] = 0;
        }
    }
    __syncthreads();
477

lishen's avatar
lishen committed
478
479
480
    // Sending phase
    if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
        goto LOW_LATENCY_COMBINE_RECV;
Chenggang Zhao's avatar
Chenggang Zhao committed
481

lishen's avatar
lishen committed
482
483
484
485
486
    // Clean up next buffer
    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
        #pragma unroll
        for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
            next_clean[i] = 0;
487

lishen's avatar
lishen committed
488
489
490
491
492
        // Notify before executing `int_p`
        syncwarp();
        if (lane_id == 0)
            atomic_add_release_global(atomic_clean_flag, num_experts);
    }
493

lishen's avatar
lishen committed
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
    // Issue IBGDA sends
    if (responsible_expert_idx < num_experts) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto local_expert_idx = responsible_expert_idx % num_local_experts;
        const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
        const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
        const auto local_x = reinterpret_cast<const int4*>(x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
        const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
        const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
                local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;

        // Unpack layout
        int offset, num_tokens_to_send;
        unpack2(layout, num_tokens_to_send, offset);

        // Issue IBGDA send
        for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += kNumWarpsPerGroup) {
            const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
            const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
            const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row + 4);

            // Copy directly to local rank, or copy to buffer and issue RDMA
            auto src_idx = __ldg(local_src_info + token_idx);
            const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4);
            if (dst_rank == rank) {
                const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
                UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
            } else {
                const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
                if (not zero_copy)
                    UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
                
                //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(hip_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
#if defined(ROCM_DISABLE_CTX)
                    rocshmem::rocshmem_schar_put_nbi_wave(
#else
                    rocshmem::rocshmem_ctx_schar_put_nbi_wave(ctx,
#endif
                    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), hidden * sizeof(hip_bfloat16), dst_rank);
535

lishen's avatar
lishen committed
536
537
538
539
540
541
542
#if defined(ROCM_DISABLE_CTX)
                    rocshmem::rocshmem_fence();
#else
                    rocshmem::rocshmem_ctx_quiet(ctx);
#endif
                }
        }
Chenggang Zhao's avatar
Chenggang Zhao committed
543

lishen's avatar
lishen committed
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
        // Put finishing flag
        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
        if (lane_id == 0){
            // volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
            volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
        }
        syncwarp();
        while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
        if (sub_warp_id == 1 and lane_id == 0) {
            while (ld_acquire_global(atomic_clean_flag) == 0);
            if (dst_rank != rank) {
#if defined(ROCM_DISABLE_CTX)
                rocshmem::rocshmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank);
#else
                rocshmem::rocshmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank);
#endif
            } else {
                st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
            }
            atomic_add_release_global(atomic_clean_flag, -1);
        }
        syncwarp();
566
567
    }

lishen's avatar
lishen committed
568
569
570
571
    // Receiving phase
    LOW_LATENCY_COMBINE_RECV:
    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
        return;
572

lishen's avatar
lishen committed
573
574
575
576
577
578
    // Wait all ranks to arrive and notify PCIe usage
    if (responsible_expert_idx < num_experts) {
        EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
        if (sub_warp_id == 0 and lane_id == 0){
            while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0);
        }
579
    }
lishen's avatar
lishen committed
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
    grid_barrier(global_atomic_counter, num_sms);

    // Reduce tokens with FP8 cast
    EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
    EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
    if (thread_id < hidden_bf16_int4) {
        for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
            // Read top-k indices and weights
            int reg_topk_idx[kNumMaxTopk];
            float reg_topk_weights[kNumMaxTopk];
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) {
                reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
                reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
            }
595

lishen's avatar
lishen committed
596
597
598
599
600
601
602
603
604
605
606
607
608
609
            float combined_values[kNumElemsPerInt4] = {0.0f};
            #pragma unroll
            for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
                // Read from sources
                auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
                auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);

                // Reduce
                auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
                const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
                #pragma unroll
                for (int j = 0; j < kNumElemsPerInt4; ++ j)
                    combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
            }
610

lishen's avatar
lishen committed
611
612
613
614
615
616
617
618
619
620
621
622
            // Write results
            int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
            auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
            #pragma unroll
            for (int j = 0; j < kNumElemsPerInt4; ++ j)
                combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
            (reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
        }
    }
#if !defined(ROCM_DISABLE_CTX)
    rocshmem::rocshmem_wg_ctx_destroy(&ctx);
#endif
623
624
}

lishen's avatar
lishen committed
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
/**
 * Host-side launcher for the low-latency combine kernel.
 *
 * Launch shape (fixed by the constants below):
 *   - grid  : ceil_div(num_experts, kNumWarpGroups) blocks, i.e. one block per
 *             group of kNumWarpGroups experts
 *   - block : kNumWarpGroups * kNumWarpsPerGroup warps of kWarpSize lanes each
 *
 * The first int of `workspace` is handed to the kernel as its atomic clean
 * flag; every other argument is forwarded to the kernel unchanged. `phases`
 * selects which of the kernel's send / receive phases run.
 *
 * SWITCH_HIDDEN presumably maps the runtime `hidden` onto a compile-time
 * kernel instantiation and rejects unsupported sizes (confirm in launch.cuh).
 *
 * NOTE(review): the kernel's receive phase calls
 * grid_barrier(global_atomic_counter, num_sms), which spins until the counter
 * equals num_sms. Nothing in this function resets the counter, so the caller
 * presumably zeroes it before each launch; verify. That spin barrier also
 * requires every block to be co-resident, because the launch is
 * non-cooperative.
 */
void combine(void* combined_x,
             void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
             const void* x, const int64_t* topk_idx, const float* topk_weights,
             const int* src_info, const int64_t* layout_range,
             int* global_atomic_counter,
             int64_t* next_clean, int num_next_clean_int,
             int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             void* workspace, hipStream_t stream,
             int phases, bool zero_copy) {
    // Compile-time kernel shape.
    constexpr int kNumWarpGroups = 4;
    constexpr int kNumWarpsPerGroup = 4;
    constexpr int kNumMaxTopk = 9;
    constexpr int kNumWarpsPerBlock = kNumWarpGroups * kNumWarpsPerGroup;

    // Validate the runtime arguments against the fixed configuration:
    // the workspace must hold the int-sized clean flag, and the kernel's
    // per-thread top-k register arrays hold at most kNumMaxTopk entries.
    EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
    EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
    int* const clean_flag = static_cast<int*>(workspace);

    // One block per kNumWarpGroups experts.
    const auto num_sms = ceil_div(num_experts, kNumWarpGroups);

    // Expanded once per supported hidden size by SWITCH_HIDDEN; the macro
    // argument is both the template parameter and the runtime hidden value.
#define COMBINE_LAUNCH_CASE(kHiddenDim) { \
    auto combine_kernel = combine<kNumWarpGroups, kNumWarpsPerGroup, kHiddenDim, kNumMaxTopk>; \
    LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_kernel, \
                                  combined_x, \
                                  rdma_recv_x, rdma_recv_flag, rdma_send_x, \
                                  x, topk_idx, topk_weights, src_info, layout_range, \
                                  global_atomic_counter, \
                                  next_clean, num_next_clean_int, \
                                  clean_flag, \
                                  num_combined_tokens, kHiddenDim, num_topk, \
                                  num_max_dispatch_tokens_per_rank, \
                                  num_experts, rank, num_ranks, \
                                  phases, zero_copy); } break

    SETUP_LAUNCH_CONFIG(num_sms, kNumWarpsPerBlock * kWarpSize, stream);
    SWITCH_HIDDEN(COMBINE_LAUNCH_CASE);
#undef COMBINE_LAUNCH_CASE
}

Chenggang Zhao's avatar
Chenggang Zhao committed
666
667
668
} // namespace internode_ll

} // namespace deep_ep