#pragma once
/*
 * Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (C) 2024-2026, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "aiter_hip_common.h"
#include "hip_float8.h"
#include "opus/opus.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <cstdint>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace aiter {

// Number of per-block slots in each Signal array. The sync helpers index these
// arrays with blockIdx.x, so a launch must not use more than kMaxBlocks blocks.
constexpr int kMaxBlocks = 80;

// Per-rank synchronization block shared between peers.
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal
{
    // start[block][rank]: slot written by `rank` for `block` in start_sync.
    alignas(128) uint32_t start[kMaxBlocks][8];
    // end[block][rank]: slot written by `rank` for `block` in end_sync.
    alignas(128) uint32_t end[kMaxBlocks][8];
    // ROCm path only: per-block counter that start_sync/end_sync advance by one
    // on every call; the expected flag value is derived from it.
    alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
};

#ifdef USE_ROCM
44
45
46
47
struct __align__(16) RankData
{
    const void* ptrs[8];
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
48
#else
49
50
51
52
struct __align__(16) RankData
{
    const void* __restrict__ ptrs[8];
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
53
54
#endif

55
56
// Table of Signal pointers with one slot per rank (8 slots). On the non-ROCm
// path the pointees are volatile because that path synchronizes with plain
// volatile loads and stores instead of scoped atomics.
struct __align__(16) RankSignals
{
#ifndef USE_ROCM
    volatile
#endif
        Signal* signals[8];
};

// Shorthand for device-side helpers that must always be inlined.
#define DINLINE __device__ __forceinline__

// scalar cast functions

// Converts one scalar of type inp_dtype to fp32 through opus::cast.
template <typename inp_dtype>
DINLINE opus::fp32_t upcast_s(inp_dtype val)
{
    return opus::cast<opus::fp32_t>(val);
}

// fp32 input needs no conversion: returned unchanged.
template <>
DINLINE opus::fp32_t upcast_s<opus::fp32_t>(opus::fp32_t val)
{
    return val;
}

// Converts one fp32 scalar to out_dtype through opus::cast.
template <typename out_dtype>
DINLINE out_dtype downcast_s(opus::fp32_t val)
{
    return opus::cast<out_dtype>(val);
}

// fp32 output needs no conversion: the value is handed back as-is.
template <>
DINLINE opus::fp32_t downcast_s<opus::fp32_t>(opus::fp32_t val)
{
    return val;
}

// packed add
// Adds b into a in place and returns a reference to a.
//  - fp32 vectors use the vector += operator directly.
//  - every other element type is added lane by lane: both operands are widened
//    to fp32 with upcast_s, summed, and narrowed back with downcast_s, so the
//    + operator of the narrow type is never needed.
template <typename T, int N>
DINLINE opus::vector_t<T, N>& packed_assign_add(opus::vector_t<T, N>& a, opus::vector_t<T, N> b)
{
    if constexpr(std::is_same<T, opus::fp32_t>::value)
    {
        a += b;
    }
    else
    {
#pragma unroll
        for(int i = 0; i < N; i++)
        {
            a[i] = downcast_s<T>(upcast_s(a[i]) + upcast_s(b[i]));
        }
    }
    return a;
}

// Widens a packed vector to a float vector with the same number of lanes.
// not support fp8 pack convert
//  - fp32 input is returned unchanged.
//  - any other element type is converted lane by lane with upcast_s.
// Only participates in overload resolution when V is an opus vector type.
template <typename V, std::enable_if_t<opus::is_vector_v<V>, bool> = true>
DINLINE auto upcast(V val) -> opus::vector_t<float, opus::vector_traits<V>::size()>
{
    using T         = typename opus::vector_traits<V>::dtype;
    constexpr int N = opus::vector_traits<V>::size();
    if constexpr(std::is_same<T, opus::fp32_t>::value)
    {
        return val;
    }
    else
    {
        opus::vector_t<float, N> out;
#pragma unroll
        for(int i = 0; i < N; i++)
        {
            out[i] = upcast_s(val[i]);
        }
        return out;
    }
}

// Narrows a packed vector `val` to the packed output type O.
//  - when O holds float elements, val is returned as-is.
//  - otherwise each of O's lanes is produced from val[i] with downcast_s.
// The lane count is taken from O; val is indexed over the same range.
// Only participates in overload resolution when V is an opus vector type.
template <typename O, typename V, std::enable_if_t<opus::is_vector_v<V>, bool> = true>
DINLINE O downcast(V val)
{
    using T         = typename opus::vector_traits<O>::dtype;
    constexpr int N = opus::vector_traits<O>::size();
    if constexpr(std::is_same<T, float>::value)
    {
        return val;
    }
    else
    {
        O out;
#pragma unroll
        for(int i = 0; i < N; i++)
        {
            out[i] = downcast_s<T>(val[i]);
        }
        return out;
    }
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
//   sg      : Signal pointers of all ranks, indexed by rank.
//   self_sg : this rank's own Signal.
//   rank    : this rank's index; selects the slot written in each peer's Signal.
//
// Within every block, thread t (t < ngpus) signals peer t and then spins until
// peer t's signal shows up in self_sg. The trailing __syncthreads() holds the
// rest of the block until those threads are done.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg,
#ifndef USE_ROCM
                        volatile
#endif
                        Signal* self_sg,
                        int rank)
{
#ifdef USE_ROCM
    // Expected value for this round: the block's stored counter plus one.
    uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
    if(threadIdx.x < ngpus)
    {
        // simultaneously write to the corresponding flag of all ranks.
        // Latency = 1 p2p write
        __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
                                flag,
                                __ATOMIC_RELAXED,
                                __MEMORY_SCOPE_SYSTEM);
        // wait until every rank has published a value >= flag
        while(__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                                     __ATOMIC_RELAXED,
                                     __MEMORY_SCOPE_DEVICE) < flag)
            ;
    }
    __syncthreads();
    // use one thread to update flag
    if(threadIdx.x == 0)
        self_sg->_flag[blockIdx.x] = flag;
#else
    if(threadIdx.x < ngpus)
    {
        // reset flag for next time
        self_sg->end[blockIdx.x][threadIdx.x] = 0;
        // simultaneously write to the corresponding flag of all ranks.
        // Latency = 1 p2p write
        sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
        // wait until we got true from all ranks
        while(!self_sg->start[blockIdx.x][threadIdx.x])
            ;
    }
    __syncthreads();
#endif
}

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
//
//   final_sync : true  -> relaxed signalling (ROCm) / no system fence (non-ROCm).
//                false -> release store + acquire load (ROCm) /
//                         __threadfence_system() before signalling (non-ROCm).
//   sg, self_sg, rank : same meaning as in start_sync.
//
// Within every block, thread t (t < ngpus) signals peer t through the `end`
// slots and spins until peer t's signal arrives in self_sg.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg,
#ifndef USE_ROCM
                      volatile
#endif
                      Signal* self_sg,
                      int rank)
{
#ifdef USE_ROCM
    // Every thread of the block must have finished its prior work before any
    // thread signals the peers.
    __syncthreads();
    // eliminate the case that prior writes are not visible after signals become
    // visible. Note that I did not managed to make this happen through a lot of
    // testing. Might be the case that hardware provides stronger guarantee than
    // the memory model.
    uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
    if(threadIdx.x < ngpus)
    {
        // simultaneously write to the corresponding flag of all ranks.
        // Latency = 1 p2p write
        __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
                                flag,
                                final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                                __MEMORY_SCOPE_SYSTEM);
        // wait until every rank has published a value >= flag
        while(__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                                     __MEMORY_SCOPE_DEVICE) < flag)
            ;
    }
    __syncthreads();
    // use one thread to update flag
    if(threadIdx.x == 0)
        self_sg->_flag[blockIdx.x] = flag;
#else
    __syncthreads();
    // eliminate the case that prior writes are not visible after signals become
    // visible. Note that I did not managed to make this happen through a lot of
    // testing. Might be the case that hardware provides stronger guarantee than
    // the memory model.
    if constexpr(!final_sync)
        __threadfence_system();
    if(threadIdx.x < ngpus)
    {
        // reset flag for next time
        self_sg->start[blockIdx.x][threadIdx.x] = 0;
        // simultaneously write to the corresponding flag of all ranks.
        // Latency = 1 p2p write
        sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
        // wait until we got true from all ranks
        while(!self_sg->end[blockIdx.x][threadIdx.x])
            ;
    }
    // A final barrier is followed by no further work in the block, so the
    // trailing block-wide barrier is only needed for the non-final case.
    if constexpr(!final_sync)
        __syncthreads();
#endif
}

// Sums element `idx` of the first `ngpus` buffers in `ptrs` and returns the
// result as a P.
//   P : packed storage type of the buffers.
//   A : packed accumulator type; each operand is widened with upcast() before
//       being added, and the sum is narrowed once at the end with downcast().
// Buffers are accumulated in index order 0, 1, ..., ngpus - 1.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx)
{
    using acc_elem_t        = typename opus::vector_traits<A>::dtype;
    constexpr int acc_lanes = opus::vector_traits<A>::size();

    A tmp = upcast(ptrs[0][idx]);
#pragma unroll
    for(int i = 1; i < ngpus; i++)
    {
        packed_assign_add<acc_elem_t, acc_lanes>(tmp, upcast(ptrs[i][idx]));
    }
    return downcast<P>(tmp);
}

// One-stage all-reduce: every rank reads element idx from all ranks' input
// buffers, sums them, and writes the sum to its own `result`.
//
//   _input_dp  : table of input buffer pointers, one per rank.
//   _output_dp : not referenced by this kernel.
//   sg/self_sg : signals used by start_sync / end_sync.
//   result     : output buffer, indexed in units of P (16-byte packs).
//   rank       : this rank's index.
//   size       : number of P packs to reduce (loop bound on the P index).
// is_broadcast_reg_outptr is not referenced by this kernel.
template <typename T, int ngpus, bool is_broadcast_reg_outptr = false>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage_naive(RankData* _input_dp,
                                                                           RankData* _output_dp,
                                                                           RankSignals sg,
#ifndef USE_ROCM
                                                                           volatile
#endif
                                                                           Signal* self_sg,
                                                                           T* __restrict__ result,
                                                                           int rank,
                                                                           int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    // note: we don't reorder the address so the accumulation order is the same
    // for all ranks, ensuring bitwise identical results
    auto dp = *_input_dp;
    start_sync<ngpus>(sg, self_sg, rank);
    // do the actual reduction
    for(int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x)
    {
        ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
    }
    // Final barrier: no rank leaves while a peer may still be reading its input.
    end_sync<ngpus, true>(sg, self_sg, rank);
}

template <typename P>
Xiaowei.zhang's avatar
Xiaowei.zhang committed
295
#ifdef USE_ROCM
296
297
DINLINE P* get_tmp_buf(Signal* sg)
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
298
#else
299
300
DINLINE P* get_tmp_buf(volatile Signal* sg)
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
301
#endif
302
303
    return (P*)(((Signal*)sg) + 1);
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
304

305
306
307
308
// Two-stage all-reduce (reduce-scatter followed by allgather).
//
// The `size` packs are split into ngpus parts of size / ngpus packs; the last
// rank additionally owns the remainder. Stage 1 reduces this rank's part into
// its own scratch buffer (get_tmp_buf); stage 2 copies every rank's reduced
// part from that rank's scratch buffer into `result`.
//
//   _input_dp  : table of input buffer pointers, one per rank.
//   _output_dp : not referenced by this kernel.
//   sg/self_sg : signals used by start_sync / end_sync; the scratch buffers
//                are located right behind each rank's Signal.
//   result     : output buffer, indexed in units of P.
//   rank       : this rank's index.
//   size       : total number of P packs.
// is_broadcast_reg_outptr is not referenced by this kernel.
template <typename T, int ngpus, bool is_broadcast_reg_outptr = false>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage_naive(RankData* _input_dp,
                                                                           RankData* _output_dp,
                                                                           RankSignals sg,
#ifndef USE_ROCM
                                                                           volatile
#endif
                                                                           Signal* self_sg,
                                                                           T* __restrict__ result,
                                                                           int rank,
                                                                           int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    int tid                 = blockIdx.x * blockDim.x + threadIdx.x;
    int stride              = gridDim.x * blockDim.x;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    int part                = size / ngpus;
    int start               = rank * part;
    int end                 = rank == ngpus - 1 ? size : start + part;
    // Size of the last rank's part (part plus the remainder).
    int largest_part = part + size % ngpus;
    const P* ptrs[ngpus];
    P* tmps[ngpus];
    // Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
    for(int i = 0; i < ngpus; i++)
    {
        int target = (rank + i) % ngpus;
        ptrs[i]    = (const P*)_input_dp->ptrs[target];
        tmps[i]    = get_tmp_buf<P>(sg.signals[target]);
    }
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);
    // stage 1: reduce scatter
    for(int idx = start + tid; idx < end; idx += stride)
    {
        tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
    }
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: allgather. Note: it's important to match the tid between
    // the two stages, because visibility across devices is only guaranteed
    // between threads that have the same tid. If thread i computes the sum of
    // start + i in the first stage, then thread i also gathers start + i from all
    // ranks.
    for(int idx = tid; idx < largest_part; idx += stride)
    {
#pragma unroll
        for(int i = 0; i < ngpus; i++)
        {
            int gather_from_rank = ((rank + i) % ngpus);
            // Only the last rank's part extends past `part` packs.
            if(gather_from_rank == ngpus - 1 || idx < part)
            {
                int dst_idx           = gather_from_rank * part + idx;
                ((P*)result)[dst_idx] = tmps[i][idx];
            }
        }
    }
}

// Thread count that the shared-memory kernels below divide evenly among the
// ranks (THREAD_NUM / ngpus threads per rank).
#define THREAD_NUM 512

// One-stage all-reduce that stages the peers' data in shared memory.
//
// The block is split into ngpus groups of tnum_gpu = THREAD_NUM / ngpus
// threads ("warp_id" is the group index, not a hardware warp). Group g loads
// packs from rank g's input buffer into shared memory; group 0 then sums the
// ngpus staged packs in rank order 0..ngpus-1 and writes the sum to `result`.
// Two shared-memory buffers are used so the loads for the next iteration
// overlap the reduction of the current one.
//
//   _input_dp  : table of input buffer pointers, one per rank.
//   _output_dp : not referenced by this kernel.
//   sg/self_sg : signals used by start_sync.
//   result     : output buffer, indexed in units of P.
//   rank       : this rank's index.
//   size       : number of P packs to reduce.
// is_broadcast_reg_outptr is not referenced by this kernel.
//
// NOTE(review): the shared buffer holds tnum_gpu * ngpus packs and is indexed
// with warp_id * tnum_gpu + lane_id, which equals threadIdx.x; this presumably
// requires blockDim.x <= tnum_gpu * ngpus - confirm against the launcher.
// NOTE(review): unlike cross_device_reduce_1stage_naive, no end_sync is issued
// before the kernel returns - confirm this is intentional.
template <typename T, int ngpus, bool is_broadcast_reg_outptr = false>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _input_dp,
                                                                     RankData* _output_dp,
                                                                     RankSignals sg,
#ifndef USE_ROCM
                                                                     volatile
#endif
                                                                     Signal* self_sg,
                                                                     T* __restrict__ result,
                                                                     int rank,
                                                                     int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;

    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    // note: we don't reorder the address so the accumulation order is the same
    // for all ranks, ensuring bitwise identical results
    auto dp     = *_input_dp;
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;

    // --- double buffer: tmp_smem[0] and tmp_smem[1] ---
    __shared__ P tmp_smem[2][tnum_gpu * ngpus];

    const int step  = gridDim.x * tnum_gpu;
    const int start = blockIdx.x * tnum_gpu + lane_id;

    start_sync<ngpus>(sg, self_sg, rank);

    // --- compute uniform iteration count (to keep barriers well-formed) ---
    // Derived from the block's first index only, so every thread of the block
    // runs the same number of iterations and reaches every __syncthreads().
    const int first = blockIdx.x * tnum_gpu;
    int iters       = 0;
    {
        int rem = size - first;
        iters   = rem > 0 ? (rem + step - 1) / step : 0;
    }

    // -------------------------------
    // fill buffer 0
    // -------------------------------
    int buf  = 0;
    int idx0 = start;

    if(idx0 < size)
    {
        P val                                       = ((const P**)&dp.ptrs[0])[warp_id][idx0];
        tmp_smem[buf][warp_id * tnum_gpu + lane_id] = val;
    }
    __syncthreads();

    for(int it = 0; it < iters; ++it)
    {
        const int cur_idx  = idx0 + it * step;
        const int next_idx = cur_idx + step;
        const int next_buf = buf ^ 1;

        // =======================================================
        // 1. Group 0 REDUCES current buffer
        // =======================================================
        if(warp_id == 0 && cur_idx < size)
        {
            // GPU 0 contribution
            P v0 = tmp_smem[buf][0 * tnum_gpu + lane_id];

            A acc;
#pragma unroll
            for(int j = 0; j < pack_size; ++j)
                acc[j] = upcast_s(v0[j]);

            // GPUs 1..(ngpus-1)
#pragma unroll
            for(int g = 1; g < ngpus; ++g)
            {
                P vg = tmp_smem[buf][g * tnum_gpu + lane_id];
#pragma unroll
                for(int j = 0; j < pack_size; ++j)
                    acc[j] += upcast_s(vg[j]);
            }

            // store result
            P out;
#pragma unroll
            for(int j = 0; j < pack_size; ++j)
                out[j] = downcast_s<T>(acc[j]);

            ((P*)result)[cur_idx] = out;
        }

        // =======================================================
        // 2. ALL groups prefetch NEXT buffer
        //    (including group 0; safe to issue after reduction)
        //    The write goes to the buffer that is NOT being reduced in this
        //    iteration, so a single barrier per iteration is sufficient.
        // =======================================================
        if(next_idx < size)
        {
            P nxt = ((const P**)&dp.ptrs[0])[warp_id][next_idx];
            tmp_smem[next_buf][warp_id * tnum_gpu + lane_id] = nxt;
        }

        __syncthreads();

        buf = next_buf;
    }
}

// Two-stage all-reduce (reduce-scatter followed by allgather) that stages the
// peers' data in shared memory.
//
// The block is split into ngpus groups of tnum_gpu = THREAD_NUM / ngpus
// threads ("warp_id" is the group index, not a hardware warp). Rank order is
// rotated so that group 0 always refers to this rank.
//   Stage 1: group g loads this rank's part from rank (rank + g) % ngpus into
//            shared memory; group 0 sums the staged packs and stores the sum
//            in this rank's scratch buffer (get_tmp_buf).
//   Stage 2: group g copies the reduced part of rank (rank + g) % ngpus from
//            that rank's scratch buffer into `result`.
//
//   _input_dp  : table of input buffer pointers, one per rank.
//   _output_dp : not referenced by this kernel.
//   sg/self_sg : signals used by start_sync / end_sync.
//   result     : output buffer, indexed in units of P.
//   rank       : this rank's index.
//   size       : total number of P packs; the last rank owns the remainder of
//                size / ngpus.
// is_broadcast_reg_outptr is not referenced by this kernel.
//
// NOTE(review): shared memory is indexed with threadIdx.x and holds
// tnum_gpu * ngpus packs; this presumably requires
// blockDim.x <= tnum_gpu * ngpus - confirm against the launcher.
template <typename T, int ngpus, bool is_broadcast_reg_outptr = false>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _input_dp,
                                                                     RankData* _output_dp,
                                                                     RankSignals sg,
#ifndef USE_ROCM
                                                                     volatile
#endif
                                                                     Signal* self_sg,
                                                                     T* __restrict__ result,
                                                                     int rank,
                                                                     int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    constexpr int tnum_gpu  = THREAD_NUM / ngpus;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    int warp_id             = threadIdx.x / tnum_gpu;
    int lane_id             = threadIdx.x % tnum_gpu;
    int tid                 = blockIdx.x * tnum_gpu + lane_id;
    int stride              = gridDim.x * tnum_gpu;
    int part                = size / ngpus;
    int start               = rank * part;
    int end                 = rank == ngpus - 1 ? size : start + part;
    // Size of the last rank's part (part plus the remainder).
    int largest_part = part + size % ngpus;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; i++)
    {
        int target = (rank + i) % ngpus;
        ptrs[i]    = (const P*)_input_dp->ptrs[target];
        tmps[i]    = get_tmp_buf<P>(sg.signals[target]);
    }
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);

    // stage 1: reduce scatter
    // The loop contains __syncthreads(), so the trip count must be identical
    // for every thread of the block. It is therefore derived from the block's
    // first index (lane 0) instead of from the per-thread index; threads whose
    // index falls past `end` still take part in the barriers but skip the
    // loads and stores.
    const int block_span = end - (start + blockIdx.x * tnum_gpu);
    const int iters      = block_span > 0 ? (block_span + stride - 1) / stride : 0;
    for(int it = 0; it < iters; ++it)
    {
        const int idx       = start + tid + it * stride;
        const bool in_range = idx < end;
        if(in_range)
        {
            *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx];
        }
        __syncthreads();
        // accumulate in group 0 (the first tnum_gpu threads of the block)
        if(warp_id == 0 && in_range)
        {
            A add_reg;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                add_reg[i] = upcast_s(tmp_smem[pack_size * threadIdx.x + i]);
            }
            // Distance in T elements between two groups' staging areas.
            constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size;
#pragma unroll
            for(int i = 1; i < ngpus; ++i)
            {
#pragma unroll
                for(int j = 0; j < pack_size; ++j)
                {
                    add_reg[j] +=
                        upcast_s(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]);
                }
            }
            P write_reg;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                write_reg[i] = downcast_s<T>(add_reg[i]);
            }
            tmp_out[idx - start] = write_reg;
        }
        // Keep the next iteration's loads from overwriting shared memory
        // while group 0 is still reading it.
        __syncthreads();
    }
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: allgather. Note: it's important to match the tid between
    // the two stages, because visibility across devices is only guaranteed
    // between threads that have the same tid. If thread i computes the sum of
    // start + i in the first stage, then thread i also gathers start + i from all
    // ranks.
    const int gather_from_rank = (warp_id + rank) % ngpus;
    for(int idx = tid; idx < largest_part; idx += stride)
    {
        // Only the last rank's part extends past `part` packs. Without this
        // guard a non-last rank's slot would be read past what it reduced and
        // the value would land in the next rank's slice of `result`.
        if(gather_from_rank == ngpus - 1 || idx < part)
        {
            int dst_idx           = gather_from_rank * part + idx;
            ((P*)result)[dst_idx] = tmps[warp_id][idx];
        }
    }
}

// Two-stage all-reduce in "write mode": data is pushed into the peers'
// scratch buffers instead of being pulled from their input buffers.
//
// The block is split into ngpus groups of tnum_gpu = THREAD_NUM / ngpus
// threads ("warp_id" is the group index, not a hardware warp).
//   Stage 1: group g copies part g of this rank's input into rank g's scratch
//            buffer at offset rank * part.
//   Stage 2: each rank sums the ngpus contributions found in its own scratch
//            buffer (accumulation in group 0) and group g pushes the sum to
//            rank g: either directly into the registered output buffer
//            (is_broadcast_reg_outptr == true, non-temporal stores) or into
//            rank g's scratch buffer at offset `size`.
//   Stage 3: only when is_broadcast_reg_outptr == false: copy the gathered
//            result from this rank's scratch buffer into `result`.
// No start_sync is issued; the first cross-rank rendezvous is the end_sync
// that follows stage 1.
//
//   _input_dp  : table of input buffer pointers, one per rank.
//   _output_dp : table of output buffer pointers; only read when
//                is_broadcast_reg_outptr is true.
//   sg/self_sg : signals used by end_sync; scratch buffers follow each Signal.
//   result     : output buffer (P units); only written in stage 3.
//   rank       : this rank's index.
//   size       : total number of P packs.
//
// NOTE(review): the scratch regions of neighbouring ranks are spaced `part`
// packs apart while the last part may be longer; this looks like it assumes
// size % ngpus == 0 - confirm against the launcher.
template <typename T, int ngpus, bool is_broadcast_reg_outptr = false>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage_write_mode(RankData* _input_dp,
                                          RankData* _output_dp,
                                          RankSignals sg,
#ifndef USE_ROCM
                                          volatile
#endif
                                          Signal* self_sg,
                                          T* __restrict__ result,
                                          int rank,
                                          int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    constexpr int tnum_gpu  = THREAD_NUM / ngpus;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    __shared__ T res_smem[tnum_gpu * pack_size];
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    int tid     = blockIdx.x * tnum_gpu + lane_id;
    int stride  = gridDim.x * tnum_gpu;
    int part    = size / ngpus;
    P* output_ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; i++)
    {
        tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    // is_broadcast_reg_outptr is a template constant: resolve it at compile
    // time so the unused path is not instantiated.
    if constexpr(is_broadcast_reg_outptr)
    {
#pragma unroll
        for(int i = 0; i < ngpus; i++)
        {
            output_ptrs[i] = (P*)_output_dp->ptrs[i];
        }
    }
    const P* input_ptr = (const P*)_input_dp->ptrs[rank];
    auto tmp_out       = tmps[rank];
    // Offset (in packs) of the stage-2 result area inside a scratch buffer.
    int stage3_offset = size;

    // stage1: write local rank data to remote rank
    int start = warp_id * part;
    int end   = warp_id == ngpus - 1 ? size : start + part;
    for(int idx = start + tid; idx < end; idx += stride)
    {
        tmps[warp_id][rank * part + idx - start] = input_ptr[idx];
    }
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: reduce scatter & write result to remote rank
    // Number of packs this rank reduces (the last rank also takes the remainder).
    const int reduce_len = rank != ngpus - 1 ? part : size - part * (ngpus - 1);
    // The loop contains __syncthreads(), so the trip count must be identical
    // for every thread of the block: derive it from the block's first index
    // and let out-of-range threads only take part in the barriers.
    const int block_span = reduce_len - blockIdx.x * tnum_gpu;
    const int iters      = block_span > 0 ? (block_span + stride - 1) / stride : 0;
    for(int it = 0; it < iters; ++it)
    {
        const int idx       = tid + it * stride;
        const bool in_range = idx < reduce_len;
        if(in_range)
        {
            *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = tmp_out[warp_id * part + idx];
        }
        __syncthreads();
        // accumulate in group 0 (the first tnum_gpu threads of the block)
        if(warp_id == 0 && in_range)
        {
            A add_reg;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                add_reg[i] = upcast_s(tmp_smem[pack_size * threadIdx.x + i]);
            }
            // Distance in T elements between two groups' staging areas.
            constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size;
#pragma unroll
            for(int i = 1; i < ngpus; ++i)
            {
#pragma unroll
                for(int j = 0; j < pack_size; ++j)
                {
                    add_reg[j] +=
                        upcast_s(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]);
                }
            }
            P write_reg;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                write_reg[i] = downcast_s<T>(add_reg[i]);
            }
            // Publish the sum so that every group can forward it to its peer.
            *(reinterpret_cast<P*>(&res_smem[0]) + lane_id) = write_reg;
        }
        __syncthreads();
        // send data to remote rank
        if(in_range)
        {
            if constexpr(is_broadcast_reg_outptr)
            {
                // One 16-byte pack is written as four 32-bit non-temporal stores.
                P temp_val    = *(reinterpret_cast<P*>(&res_smem[0]) + lane_id);
                auto src_addr = (reinterpret_cast<int*>(&temp_val));
                auto dst_addr =
                    (reinterpret_cast<int*>(&output_ptrs[warp_id][rank * part + idx]));
                __builtin_nontemporal_store(*src_addr, dst_addr);
                __builtin_nontemporal_store(*(src_addr + 1), dst_addr + 1);
                __builtin_nontemporal_store(*(src_addr + 2), dst_addr + 2);
                __builtin_nontemporal_store(*(src_addr + 3), dst_addr + 3);
            }
            else
            {
                tmps[warp_id][rank * part + idx + stage3_offset] =
                    *(reinterpret_cast<P*>(&res_smem[0]) + lane_id);
            }
        }
    }
    end_sync<ngpus>(sg, self_sg, rank);

    if constexpr(!is_broadcast_reg_outptr)
    {
        // stage 3: get the output from tmp_buffer
        end = warp_id == ngpus - 1 ? size : start + part;
        for(int idx = start + tid; idx < end; idx += stride)
        {
            ((P*)result)[idx] = tmp_out[idx + stage3_offset];
        }
    }
}

/*
 * naive allgather
 * for case: input(1345,)
 *
 * Copies `size` elements of type T from every rank's buffer into `result`,
 * with rank r's data placed at result[r * size .. r * size + size).
 * The block is split into ngpus groups of THREAD_NUM / ngpus threads; group g
 * copies from rank g (one T element per thread and iteration).
 *
 *   _dp        : table of source buffer pointers, one per rank.
 *   sg/self_sg : signals used by start_sync.
 *   result     : destination buffer holding ngpus * size elements.
 *   rank       : this rank's index.
 *   size       : number of T elements contributed by each rank.
 * */
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) allgather_naive(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size)
{
    constexpr int tnum_gpu = THREAD_NUM / ngpus;
    int warp_id            = threadIdx.x / tnum_gpu;
    int lane_id            = threadIdx.x % tnum_gpu;
    int tid                = blockIdx.x * tnum_gpu + lane_id;
    int stride             = gridDim.x * tnum_gpu;
    const T* ptrs[ngpus];

#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const T*)_dp->ptrs[i];
    }
    // Rendezvous with all ranks before reading their buffers.
    start_sync<ngpus>(sg, self_sg, rank);

    for(int idx = tid; idx < size; idx += stride)
    {
        int write_idx     = warp_id * size + idx;
        result[write_idx] = ptrs[warp_id][idx];
    }
}

// Vectorized all-gather: identical output layout to a back-to-back copy of
// every rank's buffer, but data is moved in 16-byte packs of T.
// `size` is the number of packs contributed by each rank (it bounds the index
// into the pack-typed pointers). Thread group `src_rank` copies rank
// `src_rank`'s packs to pack offset `src_rank * size` of `result`.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) allgather_vec(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size)
{
    constexpr int pack_size       = 16 / sizeof(T);
    constexpr int threads_per_gpu = THREAD_NUM / ngpus;
    using Pack                    = typename opus::vector_t<T, pack_size>;

    const int src_rank = threadIdx.x / threads_per_gpu;
    const int lane     = threadIdx.x % threads_per_gpu;
    const int step     = gridDim.x * threads_per_gpu;

    // view every rank's registered buffer as an array of packs
    const Pack* peer_bufs[ngpus];
#pragma unroll
    for(int r = 0; r < ngpus; ++r)
    {
        peer_bufs[r] = (const Pack*)_dp->ptrs[r];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    Pack* out = reinterpret_cast<Pack*>(&result[0]);
    for(int idx = blockIdx.x * threads_per_gpu + lane; idx < size; idx += step)
    {
        out[src_rank * size + idx] = peer_bufs[src_rank][idx];
    }
}

// All-gather that interleaves the ranks along the last dimension.
// Each rank's buffer is read as rows of `last_dim_size` elements
// (`last_dim_size / pack_size` 16-byte packs). Pack (y, x) of rank `warp_id`
// is written to output row y at pack column `warp_id * last_dim_size + x`,
// i.e. output rows are `ngpus` times wider than input rows.
//
// size          : number of packs contributed by each rank
// last_dim_size : last-dimension length in T elements; it is divided by
//                 pack_size, so it is expected to be a multiple of pack_size
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) allgather_lastdim(RankData* _dp,
                                                            RankSignals sg,
                                                            Signal* self_sg,
                                                            T* __restrict__ result,
                                                            int rank,
                                                            int size,
                                                            int last_dim_size)
{
    constexpr int tnum_gpu  = THREAD_NUM / ngpus;
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    int warp_id             = threadIdx.x / tnum_gpu; // source rank handled by this thread
    int lane_id             = threadIdx.x % tnum_gpu;
    int tid                 = blockIdx.x * tnum_gpu + lane_id;
    int stride              = gridDim.x * tnum_gpu;

    // from here on the row length is measured in packs
    last_dim_size /= pack_size;
    const P* ptrs[ngpus];

#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    for(int idx = tid; idx < size; idx += stride)
    {
        int y                                           = idx / last_dim_size;
        int x                                           = idx % last_dim_size;
        int write_idx                                   = (ngpus * y + warp_id) * last_dim_size + x;
        *(reinterpret_cast<P*>(&result[0]) + write_idx) = ptrs[warp_id][idx];
    }
}

/*
 * reduce_scatter, at first dim
 * range = size / (pack_size * ngpu)
 * for case:
 *  input:(ngpus * n) -> output:(n)
 *  input:(ngpus * m, n, ...) -> output(m, n, ...)
 * cond: size % (pack_size * ngpus) == 0
 *
 * Every rank reduces its own slice: packs [rank * range, (rank + 1) * range)
 * of all ranks' inputs are combined with packed_reduce and stored to packs
 * [0, range) of `result`. The pointer table is rotated so that entry 0 is
 * this rank's own buffer.
 * */
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) reduce_scatter_first_dim(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int range)
{
    int tid                 = blockIdx.x * blockDim.x + threadIdx.x;
    int stride              = blockDim.x * gridDim.x;
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    const P* ptrs[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; i++)
    {
        int target = (rank + i) % ngpus;
        ptrs[i]    = (const P*)_dp->ptrs[target];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    for(int idx = tid; idx < range; idx += stride)
    {
        int load_index  = rank * range + idx; // this rank's slice of the input
        int store_index = idx;
        *(reinterpret_cast<P*>(result) + store_index) =
            packed_reduce<P, ngpus, A>(ptrs, load_index);
    }
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
800

801
802
803
804
// fp8 quant all-reduce code start
template <typename T>
struct Fp16Filter
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
805
    static const bool value = false;
806
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
807

808
809
810
template <>
struct Fp16Filter<opus::fp16_t>
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
811
    static const bool value = true;
812
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
813

814
815
816
template <typename T>
struct Bf16Filter
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
817
    static const bool value = false;
818
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
819

820
821
822
template <>
struct Bf16Filter<opus::bf16_t>
{
Xiaowei.zhang's avatar
Xiaowei.zhang committed
823
    static const bool value = true;
824
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
825

826
827
// SFINAE helpers for the fp8-quant kernels, which currently support only
// half (fp16) and bf16. Each macro expands to a defaulted non-type template
// parameter that exists only when T is the matching dtype; it must be the
// last entry of a template parameter list that declares `T`.
#define FP16_FILTER typename std::enable_if<Fp16Filter<T>::value, void>::type* = nullptr

#define BF16_FILTER typename std::enable_if<Bf16Filter<T>::value, void>::type* = nullptr
Xiaowei.zhang's avatar
Xiaowei.zhang committed
830

831
832
833
834
835
// Horizontal reduction of one pack: folds the `size` elements of `pack`
// left to right with functor<T> (pack[0] op pack[1] op ...) and returns the
// scalar result.
template <template <typename> class functor, typename T, int size>
DINLINE T packReduce(opus::vector_t<T, size> pack)
{
    auto op   = functor<T>();
    T ret_val = pack[0];
#pragma unroll
    for(int i = 1; i < size; ++i)
    {
        ret_val = op(ret_val, pack[i]);
    }
    return ret_val;
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
843

844
845
846
// Element-wise binary operation on two packs:
// returns a pack whose i-th element is functor<T>()(a[i], b[i]).
template <template <typename> class functor, typename T, int size>
DINLINE opus::vector_t<T, size> packOp(opus::vector_t<T, size> a, opus::vector_t<T, size> b)
{
    auto op = functor<T>();
    opus::vector_t<T, size> ret_pack;
#pragma unroll
    for(int i = 0; i < size; ++i)
    {
        ret_pack[i] = op(a[i], b[i]);
    }
    return ret_pack;
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
856

857
858
859
860
// Generic addition functor: both operands are widened to fp32 with upcast_s,
// added in fp32, and the sum is narrowed back to T with downcast_s.
template <typename T>
struct AddFunctor
{
    DINLINE T operator()(T a, T b)
    {
        opus::fp32_t a_fp32 = upcast_s(a);
        opus::fp32_t b_fp32 = upcast_s(b);
        return downcast_s<T>(a_fp32 + b_fp32);
    }
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
867

868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
// fp32 specialisation: operands are already fp32, so they are added directly
// without any up/down-cast.
template <>
struct AddFunctor<opus::fp32_t>
{
    DINLINE opus::fp32_t operator()(opus::fp32_t a, opus::fp32_t b)
    {
        const opus::fp32_t sum = a + b;
        return sum;
    }
};

// Integer specialisation (used by the MLA metadata code): plain integer
// addition, no floating-point round trip.
template <>
struct AddFunctor<int>
{
    DINLINE int operator()(int a, int b)
    {
        const int sum = a + b;
        return sum;
    }
};

// Maximum functor: returns max(a, b) using whichever `max` overload is
// visible for T.
template <typename T>
struct MaxFunctor
{
    DINLINE T operator()(T a, T b)
    {
        const T larger = max(a, b);
        return larger;
    }
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
886

887
888
889
890
891
892
893
894
895
896
897
/*
 * Absolute-maximum functor: returns max(|a|, |b|).
 * The zero constant is produced with downcast_s<T> rather than a static_cast,
 * and |x| is computed as `x > 0 ? x : 0 - x` so only operators already
 * available on T are needed.
 * */
template <typename T>
struct AbsMaxFunctor
{
    DINLINE T operator()(T a, T b)
    {
        T zero_t = downcast_s<T>(0.0f);
        a        = a > zero_t ? a : zero_t - a;
        b        = b > zero_t ? b : zero_t - b;
        return max(a, b);
    }
};
Xiaowei.zhang's avatar
Xiaowei.zhang committed
905

906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
// Butterfly (XOR) exchange between lanes, implemented with ds_bpermute.
// Each lane receives `var` from the lane whose index, within its aligned
// segment of `width` lanes, equals (own index XOR mask). Only 4-byte types
// are supported because the value travels as a single int.
template <typename T>
DINLINE T shfl_xor(T var, int mask, int width = opus::get_warp_size())
{
    static_assert(sizeof(T) == 4);
    const int lane         = opus::lane_id();
    const int segment_mask = width - 1;
    const int segment_base = lane & ~segment_mask;
    const int partner      = segment_base + ((lane ^ mask) & segment_mask);
    const int payload      = __builtin_bit_cast(int, var);
    // the source lane is passed to ds_bpermute as lane index * 4
    const int fetched = __builtin_amdgcn_ds_bpermute(partner << 2, payload);
    return __builtin_bit_cast(T, fetched);
}

// Butterfly reduction across `reduce_range` neighbouring lanes.
// The XOR distance starts at reduce_range / 2 and is halved while it is
// greater than `stop_stride`; with the default stop_stride = 0 every lane of
// the group ends up holding the reduction of the whole group.
//
// shfl_xor supports 4-byte dtypes only, so:
//  * 4-byte T is exchanged directly and combined with functor<T>;
//  * any other T is widened to fp32 with upcast_s, reduced with
//    functor<float>, and narrowed back with downcast_s<T> at the end.
template <template <typename> class functor, typename T, int reduce_range, int stop_stride = 0>
DINLINE T warpReduce(T val)
{
    if constexpr(sizeof(T) == 4)
    {
        auto op = functor<T>();
#pragma unroll
        for(int stride = reduce_range / 2; stride > stop_stride; stride >>= 1)
        {
            T tmp = shfl_xor(val, stride, reduce_range);
            val   = op(val, tmp);
        }
    }
    else
    {
        auto op        = functor<float>();
        float val_fp32 = upcast_s(val);
#pragma unroll
        for(int stride = reduce_range / 2; stride > stop_stride; stride >>= 1)
        {
            float tmp = shfl_xor(val_fp32, stride, reduce_range);
            val_fp32  = op(val_fp32, tmp);
        }
        val = downcast_s<T>(val_fp32);
    }
    return val;
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
944

945
946
947
948
// Runtime reduce_range version for non-compile-time-known block sizes.
// Butterfly reduction across `reduce_range` neighbouring lanes, with the XOR
// distance halved from reduce_range / 2 down to 1.
//
// shfl_xor only accepts 4-byte types, so (matching warpReduce) a T of any
// other size is widened to fp32, reduced with functor<float>, and narrowed
// back with downcast_s<T>. 4-byte types keep using functor<T> directly.
template <template <typename> class functor, typename T>
DINLINE T warpReduceRuntime(T val, int reduce_range)
{
    if constexpr(sizeof(T) == 4)
    {
        auto op = functor<T>();
        for(int stride = reduce_range / 2; stride > 0; stride >>= 1)
        {
            T tmp = shfl_xor(val, stride, reduce_range);
            val   = op(val, tmp);
        }
    }
    else
    {
        auto op        = functor<float>();
        float val_fp32 = upcast_s(val);
        for(int stride = reduce_range / 2; stride > 0; stride >>= 1)
        {
            float tmp = shfl_xor(val_fp32, stride, reduce_range);
            val_fp32  = op(val_fp32, tmp);
        }
        val = downcast_s<T>(val_fp32);
    }
    return val;
}

// the following code only support bf16 and fp16
// pack_size must be divisible by 4 (enforced below): elements are converted
// four at a time through fp32x4 -> fp8x4.
//
// Quantizes one pack to fp8: every element is widened to fp32, divided by
// `scale_functor` (also widened to fp32) and cast to fp8.
// A scale of exactly zero would turn the division into 0/0 or x/0; in that
// case the pack is quantized to all zeros instead of NaN/inf.
template <typename T, int pack_size>
DINLINE opus::vector_t<opus::fp8_t, pack_size> packQuant(opus::vector_t<T, pack_size> inp_pack,
                                                         T scale_functor)
{
    static_assert(pack_size % 4 == 0, "packQuant converts 4 elements at a time");

    const opus::fp32_t scale      = upcast_s(scale_functor);
    const bool zero_scale         = (scale == 0.0f);
    const opus::fp32_t safe_scale = zero_scale ? 1.0f : scale;

    opus::vector_t<opus::fp8_t, pack_size> ret_val;
#pragma unroll
    for(int i = 0; i < pack_size / 4; ++i)
    {
        opus::fp32x4_t tmp;
#pragma unroll
        for(int j = 0; j < 4; ++j)
        {
            tmp[j] = zero_scale ? 0.0f : upcast_s(inp_pack[i * 4 + j]);
        }
        *(reinterpret_cast<opus::fp8x4_t*>(&ret_val) + i) =
            opus::cast<opus::fp8_t>(tmp / safe_scale);
    }
    return ret_val;
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
980

981
982
983
984
985
// Inverse of packQuant: converts a pack of fp8 values back to T.
// Elements are processed four at a time: fp8x4 is cast to fp32x4, multiplied
// by `scale_functor` (widened to fp32), and each lane is narrowed to T.
// pack_size must therefore be a multiple of 4.
template <typename T, int pack_size>
DINLINE opus::vector_t<T, pack_size> packDequant(opus::vector_t<opus::fp8_t, pack_size> inp_pack,
                                                 T scale_functor)
{
    static_assert(pack_size % 4 == 0, "packDequant converts 4 elements at a time");

    opus::vector_t<T, pack_size> ret_val;
#pragma unroll
    for(int i = 0; i < pack_size / 4; ++i)
    {
        opus::fp32x4_t tmp =
            opus::cast<opus::fp32_t>(*(reinterpret_cast<opus::fp8x4_t*>(&inp_pack) + i));
        tmp *= upcast_s(scale_functor);
#pragma unroll
        for(int j = 0; j < 4; ++j)
        {
            ret_val[i * 4 + j] = downcast_s<T>(tmp[j]);
        }
    }
    return ret_val;
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1000

1001
1002
1003
1004
1005
// Sums the pack at position `index` across all `ngpus` buffers in `ptrs`.
// Accumulation is done in fp32 (each pack is widened with upcast) starting
// from ptrs[0] and adding ptrs[1..ngpus-1] in order; the total is narrowed
// back to a pack of T with downcast.
template <typename T, int pack_size, int ngpus>
DINLINE opus::vector_t<T, pack_size>
multiGPUPackReduce(const opus::vector_t<T, pack_size>* ptrs[ngpus], int index)
{
    opus::vector_t<opus::fp32_t, pack_size> ret_val = upcast(ptrs[0][index]);
#pragma unroll
    for(int gpu_id = 1; gpu_id < ngpus; ++gpu_id)
    {
        opus::vector_t<opus::fp32_t, pack_size> tmp = upcast(ptrs[gpu_id][index]);
        ret_val += tmp;
    }
    return downcast<opus::vector_t<T, pack_size>>(ret_val);
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1014

1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
// fp8-quantized two-stage all-reduce (fp16 variant, selected via FP16_FILTER).
// too slow need to be optimized
//
// `size` is counted in packs of `pack_size` elements. Each rank owns the pack
// range [rank * part, ...) with part = size / ngpus; the last rank also takes
// the remainder.
//
// stage 1 (reduce-scatter): the owning rank sums its range across all ranks,
//   stores the full-precision result to `result`, then quantizes it to fp8.
//   One scale (group abs-max / fp8 max) is computed per group of
//   quant_scale / pack_size consecutive lanes; the group leader writes it
//   behind the fp8 payload, at &tmp_out[part] reinterpreted as T*.
// stage 2 (all-gather): every rank reads the other ranks' fp8 payloads and
//   scales from their tmp buffers, dequantizes, and fills the rest of `result`.
//
// NOTE(review): when size % ngpus != 0 the last rank's fp8 payload extends
// past tmp_out[part], where the scales are stored -- presumably callers only
// pass sizes divisible by ngpus; confirm.
// NOTE(review): warpReduce / shfl run inside loops whose trip count depends on
// the thread index; this looks like it assumes every lane of a quantization
// group is active together -- confirm against the launch configuration.
template <typename T, int quant_scale, int pack_size, int ngpus, FP16_FILTER>
__global__ __forceinline__ void __launch_bounds__(512, 1) allReduceQuantFp8(
    RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size)
{
    // lanes that share one quantization scale
    constexpr int factor_stride = quant_scale / pack_size;

    float FP8_UPBOUND = opus::cast<opus::fp32_t>(opus::numeric_limits<opus::fp8_t>::max());
    int tid           = blockIdx.x * blockDim.x + threadIdx.x;
    int stride        = gridDim.x * blockDim.x;
    using inp_pack    = opus::vector_t<T, pack_size>;
    using fp8_pack    = opus::vector_t<opus::fp8_t, pack_size>;
    int part          = size / ngpus;
    int start         = rank * part;
    int end           = rank == ngpus - 1 ? size : start + part;
    int largest_part  = part + size % ngpus;
    const inp_pack* ptrs[ngpus];
    fp8_pack* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; i++)
    {
        // rotate so that entry 0 refers to this rank
        int target = (rank + i) % ngpus;
        ptrs[i]    = (const inp_pack*)_dp->ptrs[target];
        tmps[i]    = get_tmp_buf<fp8_pack>(sg.signals[target]);
    }
    auto tmp_out = tmps[0];
    start_sync<ngpus>(sg, self_sg, rank);

    // stage 1: reduce scatter
    for(int idx = start + tid; idx < end; idx += stride)
    {
        inp_pack half8_reg       = multiGPUPackReduce<T, pack_size, ngpus>(ptrs, idx);
        ((inp_pack*)result)[idx] = half8_reg;
        // quant
        T thread_max         = packReduce<AbsMaxFunctor, T, pack_size>(half8_reg);
        thread_max           = warpReduce<MaxFunctor, T, factor_stride>(thread_max);
        T scale_factor       = downcast_s<T>(upcast_s(thread_max) / FP8_UPBOUND);
        tmp_out[idx - start] = packQuant<T, pack_size>(half8_reg, scale_factor);
        if(threadIdx.x % factor_stride == 0)
        {
            // group leader publishes the shared scale
            *(reinterpret_cast<T*>(&tmp_out[part]) + (idx - start) / factor_stride) =
                scale_factor;
        }
    }
    end_sync<ngpus>(sg, self_sg, rank);

    // stage 2: all-gather
    for(int idx = tid; idx < largest_part; idx += stride)
    {
#pragma unroll
        for(int i = 1; i < ngpus; i++)
        {
            int gather_from_rank = ((rank + i) % ngpus);
            if(gather_from_rank == ngpus - 1 || idx < part)
            {
                // dequant
                // Only the group leader loads the scale; the other lanes start
                // from a defined value (previously an uninitialized T was read)
                // and receive the leader's value through the shuffle.
                float scale_factor_fp32 = 0.0f;
                if(threadIdx.x % factor_stride == 0)
                {
                    scale_factor_fp32 = upcast_s(
                        *(reinterpret_cast<T*>(&tmps[i][part]) + idx / factor_stride));
                }
                scale_factor_fp32 = opus::shfl(scale_factor_fp32,
                                               (threadIdx.x / factor_stride) * factor_stride);
                T scale_factor     = downcast_s<T>(scale_factor_fp32);
                inp_pack half8_reg = packDequant<T, pack_size>(tmps[i][idx], scale_factor);
                int dst_idx        = gather_from_rank * part + idx;
                ((inp_pack*)result)[dst_idx] = half8_reg;
            }
        }
    }
}

// fused allreduce rmsnorm first step
//
// `size` is the element count; part = size / (pack_size * ngpus) is the number
// of 16-byte packs owned by each rank. The block is split into `ngpus` thread
// groups of `tnum_gpu` threads. For each pack of this rank's slice:
//   1. group g loads the pack from rank g's input and stages it in shared
//      memory,
//   2. group 0 sums the ngpus staged packs in fp32 and writes the result back
//      to the first tnum_gpu slots of shared memory,
//   3. group g stores that result into rank g's tmp buffer, at this rank's
//      slice (rank * part + idx).
//
// The loop is driven by a block-uniform base index so that every thread of the
// block executes the same number of __syncthreads() calls even when `part` is
// not a multiple of tnum_gpu; out-of-range lanes only skip the memory
// accesses. The barrier at the end of the body keeps the next iteration's
// staging writes from overwriting results that slower threads are still
// reading.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) reduce_scatter_cross_device_store(
    RankData* _dp, RankSignals sg, Signal* self_sg, int rank, int size)
{
    constexpr int pack_size = 16 / sizeof(T);
    constexpr int tnum_gpu  = THREAD_NUM / ngpus;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size];
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
        tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    start_sync<ngpus>(sg, self_sg, rank);

    int part = size / (pack_size * ngpus);
    for(int base = blockIdx.x * tnum_gpu; base < part; base += gridDim.x * tnum_gpu)
    {
        const int idx     = base + lane_id;
        const bool active = idx < part;

        // cross device read by all warp
        if(active)
        {
            P input_reg                                         = ptrs[warp_id][rank * part + idx];
            *(reinterpret_cast<P*>(&tmp_smem[0]) + threadIdx.x) = input_reg;
        }
        __syncthreads();
        // calculate and save in first warp
        if(warp_id == 0 && active)
        {
            A add_reg;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                add_reg[i] = upcast_s(tmp_smem[pack_size * threadIdx.x + i]);
            }
#pragma unroll
            for(int i = 1; i < ngpus; ++i)
            {
#pragma unroll
                for(int j = 0; j < pack_size; ++j)
                {
                    add_reg[j] +=
                        upcast_s(tmp_smem[i * pack_size * tnum_gpu + pack_size * threadIdx.x + j]);
                }
            }
            P add_rslt;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                add_rslt[i] = downcast_s<T>(add_reg[i]);
            }
            *(reinterpret_cast<P*>(&tmp_smem[0]) + lane_id) = add_rslt;
        }
        __syncthreads();

        // cross device store
        if(active)
        {
            P rslt                           = *(reinterpret_cast<P*>(&tmp_smem[0]) + lane_id);
            tmps[warp_id][rank * part + idx] = rslt;
        }
        // all reads of the reduced packs must finish before shared memory is
        // reused for the next chunk
        __syncthreads();
    }
    // NOTE: must use final_sync=false (RELEASE/ACQUIRE) here. Stage 2
    // (local_device_load_rmsnorm*) on each rank reads `tmps` on the
    // rank's own memory which contains IPC writes from peer ranks'
    // stage 1 kernels. With final_sync=true (RELAXED) those cross-device
    // writes are not guaranteed to be visible even after we observe the
    // peers' end flags, and the kernel produces progressively corrupted
    // output at per-rank volumes above ~1.2 MB (verified via
    // sglang/benchmark/kernels/all_reduce/repro_ar_rmsnorm_corruption.py).
    end_sync<ngpus, false>(sg, self_sg, rank);
}

// Block-wide sum of `reduce_range` floats held in shared memory.
// On return smem_addr[0] contains the total; the other entries hold partial
// sums. Must be called by every thread of the block because it contains
// __syncthreads(); the caller is responsible for a barrier between filling
// the array and calling this function.
//
// The tree halves the active range each step, so reduce_range must be a power
// of two, and the final phase reads up to index 63, so it must be >= 64.
//
// NOTE(review): the last six steps run without barriers on volatile shared
// memory and therefore rely on threads 0..31 (and the partners they read)
// executing in lockstep -- presumably a 64-wide wavefront; confirm for the
// target hardware.
template <int reduce_range>
DINLINE void smemReduceSum(float* smem_addr)
{
    static_assert(reduce_range >= 64 && (reduce_range & (reduce_range - 1)) == 0,
                  "smemReduceSum expects a power-of-two range of at least 64 floats");

    // strides larger than 32 are synchronized with block barriers
#pragma unroll
    for(int stride = reduce_range / 2; stride > 32; stride >>= 1)
    {
        if(threadIdx.x < stride)
        {
            smem_addr[threadIdx.x] += smem_addr[threadIdx.x + stride];
        }
        __syncthreads();
    }
    // a warp executes the same instruction
    volatile float* v_smem = &smem_addr[0];
    if(threadIdx.x < 32)
    {
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 32];
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 16];
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 8];
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 4];
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 2];
        v_smem[threadIdx.x] += v_smem[threadIdx.x + 1];
    }
    __syncthreads();
}

/*
 * Second step of the fused all-reduce + RMSNorm: one block per row.
 * For every row it reads the all-reduce output from this rank's tmp buffer,
 * adds `residual_inp`, writes the sum to `residual_out`, and writes
 * sum * weight * rsqrt(mean(sum^2) + eps) to `results`.
 *
 * input case n dim should be divided by 4096 with dtype bf16
 * and should be divided by 2048 with dtype fp32
 * (no bounds check is performed: each row is addressed as exactly
 * n_loop * blockDim.x packs of 16 bytes).
 *
 * tnum   : threads per block (size of the shared reduction buffer)
 * n_loop : packs handled by each thread per row
 * m, n   : number of rows / elements per row (n is the divisor of the mean)
 * */
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1)
    local_device_load_rmsnorm_naive(RankSignals sg,
                                    T* __restrict__ residual_inp,
                                    T* __restrict__ residual_out,
                                    T* __restrict__ results,
                                    T* __restrict__ weight,
                                    float eps,
                                    int rank,
                                    int m,
                                    int n)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    __shared__ float smem[tnum];
    P* tmps = get_tmp_buf<P>(sg.signals[rank]);

    for(int bid = blockIdx.x; bid < m; bid += gridDim.x)
    {
        float square_sum = 0.0f;
        A rms_inp_f32[n_loop]; // fp32 copy of (all-reduce output + residual)
        P w_arr[n_loop];
#pragma unroll
        for(int n_iter = 0; n_iter < n_loop; ++n_iter)
        {
            int read_idx        = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
            P reduce_out_pack   = tmps[read_idx];
            P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
            w_arr[n_iter] = *(reinterpret_cast<P*>(weight) + n_iter * blockDim.x + threadIdx.x);
            A reduce_pack;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                float res_inp          = upcast_s(residual_inp_pack[i]);
                float ar_out           = upcast_s(reduce_out_pack[i]);
                float rms_inp          = res_inp + ar_out;
                rms_inp_f32[n_iter][i] = rms_inp;
                reduce_pack[i]         = rms_inp * rms_inp;
            }
            square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
        }
        smem[threadIdx.x] = square_sum;
        __syncthreads();
        smemReduceSum<tnum>(&smem[0]);
        square_sum = smem[0];
        // every thread must have read smem[0] before the next row iteration
        // overwrites the shared buffer
        __syncthreads();
        float denom = rsqrtf(square_sum / n + eps);
#pragma unroll
        for(int n_iter = 0; n_iter < n_loop; ++n_iter)
        {
            P rmsnorm_rslt;
            P rmsnorm_inp;
#pragma unroll
            for(int i = 0; i < pack_size; ++i)
            {
                float x_f32     = rms_inp_f32[n_iter][i];
                float w_f32     = upcast_s(w_arr[n_iter][i]);
                rmsnorm_inp[i]  = downcast_s<T>(x_f32);
                rmsnorm_rslt[i] = downcast_s<T>(x_f32 * w_f32 * denom);
            }
            int write_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x;
            *(reinterpret_cast<P*>(results) + write_idx)      = rmsnorm_rslt;
            *(reinterpret_cast<P*>(residual_out) + write_idx) = rmsnorm_inp;
        }
    }
}

/*
 * Second step of the fused all-reduce + RMSNorm, bounds-checked variant:
 * one block per row, rows are n / pack_size packs long and a thread skips
 * any pack index >= n / pack_size.
 * For every row it reads the all-reduce output from this rank's tmp buffer,
 * adds `residual_inp`, writes the sum to `residual_out`, and writes
 * sum * weight * rsqrt(mean(sum^2) + eps) to `results`.
 *
 * block size can be 256 and 512
 * corresponding 2048 and 4096 elem per block
 *
 * tnum   : threads per block (size of the shared reduction buffer)
 * n_loop : maximum packs handled by each thread per row
 * m, n   : number of rows / elements per row
 * */
template <typename T, int tnum, int n_loop>
__global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm(RankSignals sg,
                                                                     T* __restrict__ residual_inp,
                                                                     T* __restrict__ residual_out,
                                                                     T* __restrict__ results,
                                                                     T* __restrict__ weight,
                                                                     float eps,
                                                                     int rank,
                                                                     int m,
                                                                     int n)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    __shared__ float smem[tnum];
    P* tmps = get_tmp_buf<P>(sg.signals[rank]);

    for(int bid = blockIdx.x; bid < m; bid += gridDim.x)
    {
        float square_sum = 0.0f;
        A rms_inp_f32[n_loop]; // only entries that pass the bounds check are filled
        P w_arr[n_loop];
#pragma unroll
        for(int n_iter = 0; n_iter < n_loop; ++n_iter)
        {
            if(n_iter * tnum + threadIdx.x < (n / pack_size))
            {
                int read_idx        = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
                P reduce_out_pack   = tmps[read_idx];
                P residual_inp_pack = *(reinterpret_cast<P*>(residual_inp) + read_idx);
                w_arr[n_iter]       = *(reinterpret_cast<P*>(weight) + n_iter * tnum + threadIdx.x);
                A reduce_pack;
#pragma unroll
                for(int i = 0; i < pack_size; ++i)
                {
                    float ar_out           = upcast_s(reduce_out_pack[i]);
                    float res_inp          = upcast_s(residual_inp_pack[i]);
                    float rms_inp          = ar_out + res_inp;
                    rms_inp_f32[n_iter][i] = rms_inp;
                    reduce_pack[i]         = rms_inp * rms_inp;
                }
                square_sum += packReduce<AddFunctor, float, pack_size>(reduce_pack);
            }
        }
        smem[threadIdx.x] = square_sum;
        __syncthreads();
        smemReduceSum<tnum>(&smem[0]);
        square_sum = smem[0];
        // every thread must have read smem[0] before the next row iteration
        // overwrites the shared buffer
        __syncthreads();
        float denom = rsqrtf(square_sum / n + eps);
#pragma unroll
        for(int n_iter = 0; n_iter < n_loop; ++n_iter)
        {
            if(n_iter * tnum + threadIdx.x < (n / pack_size))
            {
                P rmsnorm_rslt;
                P rmsnorm_inp;
#pragma unroll
                for(int i = 0; i < pack_size; ++i)
                {
                    float x_f32     = rms_inp_f32[n_iter][i];
                    float w_f32     = upcast_s(w_arr[n_iter][i]);
                    rmsnorm_inp[i]  = downcast_s<T>(x_f32);
                    rmsnorm_rslt[i] = downcast_s<T>(x_f32 * w_f32 * denom);
                }
                int write_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x;
                *(reinterpret_cast<P*>(results) + write_idx)      = rmsnorm_rslt;
                *(reinterpret_cast<P*>(residual_out) + write_idx) = rmsnorm_inp;
            }
        }
    }
}

// Second step of the fused all-reduce + RMSNorm for short rows: each group of
// 64 threads handles one row of 64 * n_loop packs, so a block processes
// blockDim.x / 64 rows at a time and no shared memory is needed -- the sum of
// squares is reduced with warpReduce over the 64 lanes.
// Per row: x = all-reduce output (from this rank's tmp buffer) + residual_inp;
// residual_out = x; results = x * weight * rsqrt(sum(x^2) / n + eps).
template <typename T, int n_loop>
__global__ void __launch_bounds__(256, 1)
    local_device_load_rmsnorm_512n(RankSignals sg,
                                   T* __restrict__ residual_inp,
                                   T* __restrict__ residual_out,
                                   T* __restrict__ results,
                                   T* __restrict__ weight,
                                   float eps,
                                   int rank,
                                   int m,
                                   int n)
{
    constexpr int kLanes    = 64;
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;

    P* reduced_in   = get_tmp_buf<P>(sg.signals[rank]);
    P* residual_src = reinterpret_cast<P*>(residual_inp);
    P* residual_dst = reinterpret_cast<P*>(residual_out);
    P* result_dst   = reinterpret_cast<P*>(results);
    P* weight_src   = reinterpret_cast<P*>(weight);

    const int warp_in_block   = threadIdx.x / kLanes;
    const int lane            = threadIdx.x % kLanes;
    const int warps_per_block = blockDim.x / kLanes;
    const int row_step        = gridDim.x * warps_per_block;

    for(int row = blockIdx.x * warps_per_block + warp_in_block; row < m; row += row_step)
    {
        const int row_base = row * kLanes * n_loop;
        A summed[n_loop];  // fp32 copy of (all-reduce output + residual)
        P gains[n_loop];   // weight packs for the columns owned by this lane
        float sq_total = 0.0f;

#pragma unroll
        for(int it = 0; it < n_loop; ++it)
        {
            const int col      = it * kLanes + lane;
            const int pack_idx = row_base + col;
            const P ar_pack    = reduced_in[pack_idx];
            const P res_pack   = residual_src[pack_idx];
            gains[it]          = weight_src[col];
            A squares;
#pragma unroll
            for(int e = 0; e < pack_size; ++e)
            {
                const float v = upcast_s(ar_pack[e]) + upcast_s(res_pack[e]);
                summed[it][e] = v;
                squares[e]    = v * v;
            }
            sq_total += packReduce<AddFunctor, float, pack_size>(squares);
        }

        sq_total              = warpReduce<AddFunctor, float, 64>(sq_total);
        const float inv_rms   = rsqrtf(sq_total / n + eps);

#pragma unroll
        for(int it = 0; it < n_loop; ++it)
        {
            P normed;
            P passthrough;
#pragma unroll
            for(int e = 0; e < pack_size; ++e)
            {
                const float v    = summed[it][e];
                const float gain = upcast_s(gains[it][e]);
                passthrough[e]   = downcast_s<T>(v);
                normed[e]        = downcast_s<T>(v * gain * inv_rms);
            }
            const int pack_idx     = row_base + it * kLanes + lane;
            result_dst[pack_idx]   = normed;
            residual_dst[pack_idx] = passthrough;
        }
    }
}

// Block-wide reduction of `val` with `functor`.
//
// Stage 1: each group of WARP_SIZE threads reduces its values with warpReduce
//          and its first lane parks the partial result in shared memory.
// Stage 2: every warp reloads the per-warp partials (lane i reads partial i)
//          and reduces them with warpReduceRuntime over `reduce_width` lanes,
//          where reduce_width is num_warps rounded up to a power of two.
//
// block_size: number of threads taking part in the reduction. It is rounded UP
//             to whole warps, so a trailing partial warp is still counted and a
//             block smaller than one warp still yields its own partial instead
//             of T(0). (The previous floor division silently dropped both.)
//
// Lanes without a matching partial contribute T(0).
// NOTE(review): T(0) is only a neutral element for some functors (it is for
// addition, and for max over non-negative inputs as used by the abs-max
// caller) — confirm before using this helper with other functors.
//
// The callers visible in this file consume the result from thread 0 only.
// Must be reached by every thread of the block (contains __syncthreads()).
template <template <typename> class functor, typename T, int WARP_SIZE = 32>
__device__ __forceinline__ T ar_fusion_epilogue_block_reduce(T val, int block_size)
{
    static __shared__ T shared[32]; // max 1024 / 32 = 32
    const int tid       = threadIdx.x;
    const int w_tid     = tid % WARP_SIZE;
    const int wid       = tid / WARP_SIZE;
    // Ceil division: the write side below stores one partial per (possibly
    // partial) warp, so the read side must count the same number of warps.
    const int num_warps = (block_size + WARP_SIZE - 1) / WARP_SIZE;
    // round up to next power of 2 for shfl_xor correctness
    int reduce_width    = 1;
    while(reduce_width < num_warps)
        reduce_width <<= 1;
    val                 = warpReduce<functor, T, WARP_SIZE>(val);
    if(w_tid == 0)
    {
        shared[wid] = val;
    }
    __syncthreads();
    val = (w_tid < num_warps) ? shared[w_tid] : T(0);
    // Second barrier: all reads of `shared` finish before a later call may
    // overwrite it.
    __syncthreads();
    val = warpReduceRuntime<functor, T>(val, reduce_width);
    return val;
}

// RMSNorm of one row that is spread across the whole thread block.
//
// Each thread holds PACK_SIZE values of the row in `in` and the matching
// weights in `weight`. The sum of squares is reduced over the block, thread 0
// turns it into rsqrt(sum / hidden_dim + eps) and publishes it through shared
// memory, and every thread then writes in * inv_rms * weight into `out`
// (converted to OT).
//
// block_size is forwarded to the block reduction. Every thread of the block
// must call this function (it synchronizes the block).
template <typename P,
          typename A,
          typename O,
          typename OT,
          int PACK_SIZE,
          int WARP_SIZE = 32>
__device__ __forceinline__ void
ar_fusion_epilogue_rms_norm(O& out, A& in, P& weight, float eps, int hidden_dim, int block_size)
{
    __shared__ float s_inv_rms;

    // Per-thread partial sum of squares.
    float sq_partial = 0.f;
#pragma unroll
    for(int e = 0; e < PACK_SIZE; ++e)
    {
        float x = upcast_s(in[e]);
        sq_partial += x * x;
    }

    float sq_total =
        ar_fusion_epilogue_block_reduce<AddFunctor, float, WARP_SIZE>(sq_partial, block_size);

    // Only thread 0 computes the normalizer; the barrier makes it visible to all.
    if(threadIdx.x == 0)
        s_inv_rms = rsqrtf(sq_total / hidden_dim + eps);
    __syncthreads();

    float inv_rms = s_inv_rms;
#pragma unroll
    for(int e = 0; e < PACK_SIZE; ++e)
    {
        float scaled = in[e] * inv_rms * upcast_s(weight[e]);
        out[e]       = downcast_s<OT>(scaled);
    }
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1457

1458
1459
1460
1461
1462
1463
1464
1465
// Block-wide maximum of |x| over every element held in `data`.
//
// Each thread first takes the abs-max of its own PACK_SIZE values, the
// per-thread maxima are combined with the block reduction (MaxFunctor), and
// thread 0 broadcasts the result through shared memory so that every thread
// returns the same value.
//
// block_size is forwarded to the block reduction. Every thread of the block
// must call this function (it synchronizes the block).
template <typename A, int PACK_SIZE, int WARP_SIZE = 32>
__device__ __forceinline__ float ar_fusion_epilogue_reduce_abs_max(A& data, int block_size)
{
    __shared__ float s_val;
    auto fn   = [](float a, float b) { return a > b ? a : b; };
    // Starts below any absolute value, so the first element always wins.
    float acc = -1.f;
#pragma unroll
    for(int i = 0; i < PACK_SIZE; ++i)
    {
        float v = upcast_s(data[i]);
        acc     = fn(acc, std::abs(v));
    }
    acc = ar_fusion_epilogue_block_reduce<MaxFunctor, float, WARP_SIZE>(acc, block_size);
    // Broadcast thread 0's result to the whole block.
    if(threadIdx.x == 0)
    {
        s_val = acc;
    }
    __syncthreads();
    acc = s_val;
    return acc;
}

// Epilogue shared by the fused allreduce kernels: RMSNorm of the row held in
// `in`, followed by either a plain store or a per-row quantized store.
//
//   T == OutT : the normalized pack is converted to T and written to
//               output + idx.
//   T != OutT : the normalized row is kept in fp32, its block-wide abs-max is
//               turned into scale = amax / max(fp8) (1.0 when amax is 0), the
//               pack is quantized with packQuant and written to output + idx,
//               and thread 0 stores the scale to scale_out[tidx].
//
// idx    : element offset of this thread's pack inside `output`.
// tidx   : row index used for scale_out.
// active : threads with active == false still take part in all block-wide
//          reductions but skip the store to `output`.
template <typename P, typename A, typename T, typename OutT, int PACK_SIZE>
__device__ __forceinline__ void ar_fusion_epilogue(A& in,
                                                   P& weight,
                                                   int hidden_dim,
                                                   float eps,
                                                   int idx,
                                                   int tidx,
                                                   int block_size,
                                                   OutT* __restrict__ output,
                                                   float* __restrict__ scale_out,
                                                   bool active = true)
{
    if constexpr(!std::is_same_v<T, OutT>)
    {
        // Quantizing path: normalize in fp32, derive one scale per row.
        using QuantPack = opus::vector_t<OutT, PACK_SIZE>;
        float fp8_max   = opus::cast<opus::fp32_t>(opus::numeric_limits<opus::fp8_t>::max());

        A normed;
        ar_fusion_epilogue_rms_norm<P, A, A, float, PACK_SIZE>(
            normed, in, weight, eps, hidden_dim, block_size);

        float row_amax = ar_fusion_epilogue_reduce_abs_max<A, PACK_SIZE>(normed, block_size);
        float scale    = 1.f;
        if(!(row_amax == 0.f))
            scale = row_amax / fp8_max;

        QuantPack quantized = packQuant<opus::fp32_t, PACK_SIZE>(normed, scale);
        if(active)
            *reinterpret_cast<QuantPack*>(output + idx) = quantized;
        if(threadIdx.x == 0)
            scale_out[tidx] = scale;
    }
    else
    {
        // Same-type path: normalize straight into a pack of T.
        P normed;
        ar_fusion_epilogue_rms_norm<P, A, P, T, PACK_SIZE>(
            normed, in, weight, eps, hidden_dim, block_size);
        if(active)
            *reinterpret_cast<P*>(output + idx) = normed;
    }
}

// RMSNorm epilogue with per-group FP8 quantization.
//
// group_size is expressed in elements (e.g. 128). A group is covered by
// group_size / PACK_SIZE consecutive threads, and each such thread group
// derives its own abs-max and scale.
// scale_out is written as a row-major (M, hidden_dim / group_size) array:
// scale_out[tidx * (hidden_dim / group_size) + group].
//
// If bf16_output is non-null, the normalized (not yet quantized) values are
// additionally stored there, converted to T.
// Threads with active == false join the block-wide RMSNorm reduction and the
// shuffles but perform no stores.
template <typename P, typename A, typename T, typename OutT, int PACK_SIZE>
__device__ __forceinline__ void ar_fusion_epilogue_per_group(
    A& in,
    P& weight,
    int hidden_dim,
    float eps,
    int idx,
    int tidx,
    int block_size,
    int group_size,
    OutT* __restrict__ output,
    float* __restrict__ scale_out,
    bool active = true,
    T* __restrict__ bf16_output = nullptr)
{
    static_assert(!std::is_same_v<T, OutT>, "per-group quant requires FP8 output");
    using QuantPack = opus::vector_t<OutT, PACK_SIZE>;
    float fp8_max   = opus::cast<opus::fp32_t>(opus::numeric_limits<opus::fp8_t>::max());

    // Step 1: RMSNorm over the full row (block-wide reduction), kept in fp32.
    A normed;
    ar_fusion_epilogue_rms_norm<P, A, A, float, PACK_SIZE>(
        normed, in, weight, eps, hidden_dim, block_size);

    // Optional unquantized copy, so consumers that also need the plain normed
    // tensor (e.g. GDN-style layers such as Qwen3.5 in_proj_ba) do not need a
    // separate per-group quant kernel.
    if(active && bf16_output != nullptr)
    {
        P plain_pack;
#pragma unroll
        for(int e = 0; e < PACK_SIZE; ++e)
            plain_pack[e] = downcast_s<T>(normed[e]);
        *reinterpret_cast<P*>(bf16_output + idx) = plain_pack;
    }

    // Step 2: per-group abs-max and quantization.
    int lanes_per_group = group_size / PACK_SIZE;
    int groups_per_row  = hidden_dim / group_size;
    int my_group        = threadIdx.x / lanes_per_group;
    int my_lane         = threadIdx.x % lanes_per_group;

    // Abs-max of this thread's own pack (seed is below any absolute value).
    float group_amax = -1.f;
#pragma unroll
    for(int e = 0; e < PACK_SIZE; ++e)
    {
        float v   = upcast_s(normed[e]);
        float mag = std::abs(v);
        group_amax = group_amax > mag ? group_amax : mag;
    }

    // XOR-shuffle exchange across the lanes_per_group threads of this group,
    // halving the distance each step.
    int distance = lanes_per_group / 2;
    while(distance > 0)
    {
        float peer = __shfl_xor(group_amax, distance, lanes_per_group);
        group_amax = group_amax > peer ? group_amax : peer;
        distance >>= 1;
    }

    // Every thread of the group now uses the same scale.
    float scale = 1.f;
    if(!(group_amax == 0.f))
        scale = group_amax / fp8_max;

    QuantPack quantized = packQuant<opus::fp32_t, PACK_SIZE>(normed, scale);
    if(active)
        *reinterpret_cast<QuantPack*>(output + idx) = quantized;

    // One scale per (row, group), written by the group's first lane.
    if(my_lane == 0 && active)
        scale_out[tidx * groups_per_row + my_group] = scale;
}

// One-stage fused allreduce + residual add + RMSNorm (+ optional per-row quant).
//
// Mapping used by the code: blockIdx.x is the row (token) index and each thread
// owns one pack of 16 / sizeof(T) elements of that row. Threads whose pack lies
// beyond hidden_dim are "inactive": they load and store nothing, but still
// enter the epilogue with a zero accumulator so the block-wide reductions see
// every thread of the block.
//
// Per active thread:
//   1. sum the pack from all ngpus registered input pointers in fp32;
//   2. round the sum to T and back to fp32;
//   3. add residual_inp, store the result (as T) to residual_out;
//   4. hand the fp32 values and the weight pack to ar_fusion_epilogue, which
//      normalizes and writes `output` (and `scale_out` when OutT != T).
//
// `size` is not referenced by this kernel.
template <typename T, typename OutT, int ngpus>
__global__ void __launch_bounds__(1024, 1)
    allreduce_fusion_kernel_1stage(RankData* _dp,
                                   RankSignals sg,
                                   Signal* self_sg,
                                   int rank,
                                   T* __restrict__ residual_inp,
                                   T* __restrict__ residual_out,
                                   OutT* __restrict__ output,
                                   T* __restrict__ weight,
                                   float* __restrict__ scale_out,
                                   int size,
                                   int hidden_dim,
                                   float eps)
{
    constexpr int pack_size = 16 / sizeof(T);
    int block_size          = hidden_dim / pack_size;
    bool active             = (int)threadIdx.x < block_size;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    int tidx                = blockIdx.x;
    int access_id_in_token  = threadIdx.x * pack_size;
    int idx                 = tidx * hidden_dim + access_id_in_token;
    const P* ptrs[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    A acc{};
    P vec{};
    P weight_p{};
    if(active)
    {
        vec = ptrs[0][idx / pack_size];
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            acc[v] = upcast_s(vec[v]);
        }

#pragma unroll
        for(int r = 1; r < ngpus; ++r)
        {
            vec = ptrs[r][idx / pack_size];
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
            {
                acc[v] += upcast_s(vec[v]);
            }
        }

        // Round the allreduce result to T and back to f32 before adding the residual,
        // matching the numerical behavior of the unfused (allreduce -> T -> add residual) path.
        // Without this, the extra f32 mantissa bits cause 1-ULP divergence that compounds across layers.
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            acc[v] = upcast_s(downcast_s<T>(acc[v]));
        }

        P res = *reinterpret_cast<P*>(residual_inp + idx);

#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            acc[v] += upcast_s(res[v]);
        }

#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            vec[v] = downcast_s<T>(acc[v]);
        }

        *reinterpret_cast<P*>(residual_out + idx) = vec;
        weight_p = *reinterpret_cast<P*>(weight + access_id_in_token);
    }
    // padded threads participate in reduction with zero acc but skip output writes
    int padded_block_size = (int)blockDim.x;
    ar_fusion_epilogue<P, A, T, OutT, pack_size>(
        acc, weight_p, hidden_dim, eps, idx, tidx, padded_block_size, output, scale_out, active);
}

// Per-group quant variant of the 1-stage fused allreduce+rmsnorm kernel.
// scale_out shape: (m, hidden_dim / group_size) instead of (m, 1).
//
// Mapping used by the code: blockIdx.x is the row (token) index and each thread
// owns one pack of 16 / sizeof(T) elements of that row. Threads whose pack lies
// beyond hidden_dim are inactive: they load and store nothing, but still enter
// the epilogue with a zero accumulator so block-wide reductions and shuffles
// see every thread of the block.
//
// Per active thread: sum the pack from all ngpus registered input pointers in
// fp32, round the sum to T and back, add residual_inp, store the result to
// residual_out, then let ar_fusion_epilogue_per_group normalize and quantize.
// When bf16_output is non-null, the epilogue also stores the unquantized
// normalized values there.
//
// `size` is not referenced by this kernel.
template <typename T, typename OutT, int ngpus>
__global__ void __launch_bounds__(1024, 1)
    allreduce_fusion_kernel_1stage_per_group(RankData* _dp,
                                             RankSignals sg,
                                             Signal* self_sg,
                                             int rank,
                                             T* __restrict__ residual_inp,
                                             T* __restrict__ residual_out,
                                             OutT* __restrict__ output,
                                             T* __restrict__ weight,
                                             float* __restrict__ scale_out,
                                             int size,
                                             int hidden_dim,
                                             int group_size,
                                             float eps,
                                             T* __restrict__ bf16_output = nullptr)
{
    constexpr int pack_size = 16 / sizeof(T);
    int block_size          = hidden_dim / pack_size;
    bool active             = (int)threadIdx.x < block_size;
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    int tidx                = blockIdx.x;
    int access_id_in_token  = threadIdx.x * pack_size;
    int idx                 = tidx * hidden_dim + access_id_in_token;
    const P* ptrs[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
    }
    start_sync<ngpus>(sg, self_sg, rank);

    A acc{};
    P vec{};
    P weight_p{};
    if(active)
    {
        vec = ptrs[0][idx / pack_size];
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            acc[v] = upcast_s(vec[v]);

#pragma unroll
        for(int r = 1; r < ngpus; ++r)
        {
            vec = ptrs[r][idx / pack_size];
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
                acc[v] += upcast_s(vec[v]);
        }

        // Round the allreduce result to T and back to f32 before the residual
        // add, as the non-per-group 1-stage kernel does.
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            acc[v] = upcast_s(downcast_s<T>(acc[v]));

        P res = *reinterpret_cast<P*>(residual_inp + idx);
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            acc[v] += upcast_s(res[v]);

#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            vec[v] = downcast_s<T>(acc[v]);

        *reinterpret_cast<P*>(residual_out + idx) = vec;
        weight_p = *reinterpret_cast<P*>(weight + access_id_in_token);
    }
    // Inactive (padding) threads still join the epilogue with a zero
    // accumulator; the reduction width is the launched block size.
    int padded_block_size = (int)blockDim.x;
    ar_fusion_epilogue_per_group<P, A, T, OutT, pack_size>(
        acc, weight_p, hidden_dim, eps, idx, tidx, padded_block_size,
        group_size, output, scale_out, active, bf16_output);
}

// Host-side launcher for allreduce_fusion_kernel_1stage_per_group.
//
// Launch shape: one block per row (size / hidden_dim rows), with
// hidden_dim / (16 / sizeof(T)) threads rounded up to a multiple of 32 so the
// kernel's block reduction works on whole warps.
//
// Throws std::runtime_error when the row count exceeds kMaxBlocks: the kernel
// runs one block per row and the per-block Signal arrays are sized by
// kMaxBlocks, so larger grids are rejected, as in the other fused launchers
// in this file.
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_1stage_per_group_launcher(
    RankData* _dp, RankSignals sg, Signal* self_sg, int rank,
    T* residual_inp, T* residual_out, OutT* output, T* weight,
    float* scale_out, int size, int hidden_dim, int group_size,
    float eps, hipStream_t stream, T* bf16_output = nullptr)
{
    constexpr int pack_size = 16 / sizeof(T);
    constexpr int warp_size = 32;
    int block_size          = hidden_dim / pack_size;
    int padded_size         = (block_size + warp_size - 1) / warp_size * warp_size;
    int m                   = size / hidden_dim;
    if(m > kMaxBlocks)
        throw std::runtime_error(
            "Token number is too large for allreduce_fusion_kernel_1stage_per_group kernel");
    dim3 grid(m);
    dim3 block(padded_size);
    allreduce_fusion_kernel_1stage_per_group<T, OutT, NGPUS>
        <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank,
                                     residual_inp, residual_out,
                                     output, weight, scale_out,
                                     size, hidden_dim, group_size, eps,
                                     bf16_output);
}

// Host-side launcher for allreduce_fusion_kernel_1stage.
//
// Launch shape: one block per row (size / hidden_dim rows); the thread count is
// hidden_dim / (16 / sizeof(T)) rounded up to a multiple of 32.
// Throws std::runtime_error when the row count exceeds kMaxBlocks.
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_1stage_launcher(RankData* _dp,
                                             RankSignals sg,
                                             Signal* self_sg,
                                             int rank,
                                             T* residual_inp,
                                             T* residual_out,
                                             OutT* output,
                                             T* weight,
                                             float* scale_out,
                                             int size,
                                             int hidden_dim,
                                             float eps,
                                             hipStream_t stream)
{
    constexpr int kPack = 16 / sizeof(T);
    constexpr int kWarp = 32;

    const int packs_per_row = hidden_dim / kPack;
    // Whole warps only, so the in-kernel block reduction sees full warps.
    const int threads = ((packs_per_row + kWarp - 1) / kWarp) * kWarp;
    const int rows    = size / hidden_dim;

    if(rows > kMaxBlocks)
    {
        throw std::runtime_error(
            "Token number is too large for allreduce_fusion_kernel_1stage kernel");
    }

    const dim3 grid(rows);
    const dim3 block(threads);
    allreduce_fusion_kernel_1stage<T, OutT, NGPUS><<<grid, block, 0, stream>>>(
        _dp, sg, self_sg, rank, residual_inp, residual_out, output, weight, scale_out, size,
        hidden_dim, eps);
}

// Fused q/k RMSNorm with cross-rank averaging of the mean-square, plus a
// pass-through copy of v.
//
// Mapping used by the code: blockIdx.x is the token index and each thread owns
// one pack of 16 / sizeof(T) elements of the concatenated [q | k | v] row,
// which is read from this rank's registered input pointer (_dp->ptrs[rank]).
//
//   q / k threads: compute this rank's mean of squares for their section,
//                  publish it through this rank's tmp buffer (slots
//                  tidx * 2 + 0 for q and tidx * 2 + 1 for k), add the values
//                  published by the other ranks after end_sync, and normalize
//                  with rsqrt(sum_over_ranks / ngpus + eps) and q_w / k_w.
//   v threads    : copy their pack unchanged to v_out.
//
// `qkv_in` and `token_num` are not referenced in the body.
// NOTE(review): __syncthreads() and end_sync() are only reached by threads
// inside the `is_qk` branch; v threads skip them. Confirm this is safe on all
// targeted hardware / wave sizes.
template <typename T, int ngpus, int WARP_SIZE>
__global__ void __launch_bounds__(1024, 1)
    qknorm_allreduce_fusion_kernel_2stage(RankData* _dp,
                                          RankSignals sg,
                                          Signal* self_sg,
                                          int rank,
                                          T* __restrict__ qkv_in,
                                          T* __restrict__ q_w,
                                          T* __restrict__ k_w,
                                          T* __restrict__ q_out,
                                          T* __restrict__ k_out,
                                          T* __restrict__ v_out,
                                          int token_num,
                                          int hidden_dim_q,
                                          int hidden_dim_k,
                                          int hidden_dim_v,
                                          float eps)
{
    constexpr int pack_size = 16 / sizeof(T);
    int hidden_dim_qk       = hidden_dim_q + hidden_dim_k;
    int hidden_dim          = hidden_dim_q + hidden_dim_k + hidden_dim_v;
    // Section of the row this thread's pack falls into.
    bool is_q               = (int)threadIdx.x * pack_size < hidden_dim_q;
    bool is_qk              = (int)threadIdx.x * pack_size < hidden_dim_qk;
    // NOTE(review): is_k is computed but never read below.
    bool is_k               = (!is_q) && (is_qk);
    // First thread of the q section and first thread of the k section; these
    // two threads do the per-section accumulation and cross-rank exchange.
    bool is_q_t0            = threadIdx.x == 0;
    bool is_k_t0            = threadIdx.x == (hidden_dim_q / pack_size);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    int tidx                = blockIdx.x;
    int access_id_in_token  = threadIdx.x * pack_size;
    int idx                 = tidx * hidden_dim + access_id_in_token;
    int wid                 = threadIdx.x / WARP_SIZE;

    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
        tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    start_sync<ngpus>(sg, self_sg, rank);

    // One partial sum of squares per group of WARP_SIZE threads; slots 0 and 1
    // are later reused for the final q and k values.
    __shared__ float smem[32];

    A acc{};
    P vec{};
    P var_vec{};
    P weight_p{};
    float sum2 = 0.0f;

    // Every thread (q, k and v) loads its pack from this rank's own input.
    vec = ptrs[rank][idx / pack_size];

    if (is_qk) {

        // Per-thread sum of squares, kept in fp32; acc is reused for the
        // normalization further down.
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            acc[v] = upcast_s(vec[v]);
            sum2 += acc[v] * acc[v];
        }
        
        sum2 = warpReduce<AddFunctor, float, WARP_SIZE>(sum2);

        if (threadIdx.x % WARP_SIZE == 0)
            smem[wid] = sum2;
        __syncthreads();

        sum2 = 0.0f;

        // Section leaders add up the per-warp partials of their section
        // (q: warps [0, hidden_dim_q / WARP_SIZE / pack_size), k: the warps
        // after that up to hidden_dim_qk / WARP_SIZE / pack_size), divide by
        // the section width and publish the local mean-square as the first
        // float of a pack in this rank's tmp buffer.
        if (is_q_t0) {
            for (int i = 0; i < hidden_dim_q / WARP_SIZE / pack_size; ++i) {
                sum2 += smem[i];
            }
            sum2 /= (float)hidden_dim_q;
            *reinterpret_cast<float*>(&var_vec) = sum2;
            tmps[rank][tidx * 2 + 0] = var_vec;
        } else if (is_k_t0) {
            for (int i = hidden_dim_q / WARP_SIZE / pack_size; i < hidden_dim_qk / WARP_SIZE / pack_size; ++i) {
                sum2 += smem[i];
            }
            sum2 /= (float)hidden_dim_k;
            *reinterpret_cast<float*>(&var_vec) = sum2;
            tmps[rank][tidx * 2 + 1] = var_vec;
        }

        end_sync<ngpus>(sg, self_sg, rank);

        // Section leaders still hold their local mean-square in sum2; add the
        // values published by every other rank and store the total in
        // smem[0] (q) / smem[1] (k) for the rest of the block.
        // NOTE(review): smem[0] / smem[1] were read as per-warp partials
        // before end_sync; this reuse presumably relies on end_sync ordering
        // those reads before these writes — confirm.
        if (is_q_t0) {
#pragma unroll
            for(int r = 1; r < ngpus; ++r)
            {
                int target = (rank + r) % ngpus;
                auto peer_vec = tmps[target][tidx * 2 + 0];
                sum2 += *reinterpret_cast<float*>(&peer_vec);
            }
            smem[0] = sum2;
        } else if (is_k_t0) {
#pragma unroll
            for(int r = 1; r < ngpus; ++r)
            {
                int target = (rank + r) % ngpus;
                auto peer_vec = tmps[target][tidx * 2 + 1];
                sum2 += *reinterpret_cast<float*>(&peer_vec);
            }
            smem[1] = sum2;
        }

        __syncthreads();

        // Normalize with the rank-averaged mean-square and the section weight.
        if (is_q) {
            weight_p = *reinterpret_cast<P*>(&q_w[access_id_in_token]);
            float denom = rsqrtf(smem[0] / ngpus + eps);
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
            {
                vec[v] = downcast_s<T>(acc[v] * denom * upcast_s(weight_p[v]));
            }
            *reinterpret_cast<P*>(&q_out[tidx * hidden_dim_q + access_id_in_token]) = vec;
        } else {
            // k section: offsets are relative to the start of k inside the row.
            weight_p = *reinterpret_cast<P*>(&k_w[access_id_in_token - hidden_dim_q]);
            float denom = rsqrtf(smem[1] / ngpus + eps);
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
            {
                vec[v] = downcast_s<T>(acc[v] * denom * upcast_s(weight_p[v]));
            }
            *reinterpret_cast<P*>(&k_out[tidx * hidden_dim_k + access_id_in_token - hidden_dim_q]) = vec;
        }
    
    } else {
        // v section: plain copy of the loaded pack.
        *reinterpret_cast<P*>(&v_out[tidx * hidden_dim_v + access_id_in_token - hidden_dim_qk]) = vec;
    }
}

// Host-side launcher for qknorm_allreduce_fusion_kernel_2stage.
//
// Launch shape: one block per token, one thread per pack of 16 / sizeof(T)
// elements of the concatenated q|k|v row.
//
// Throws std::runtime_error when
//   - token_num exceeds kMaxBlocks, or
//   - any of hidden_dim_q / hidden_dim_k / hidden_dim_v is not a multiple of
//     32 * (16 / sizeof(T)) elements, i.e. a section does not end on a
//     32-thread boundary.
template <typename T, int NGPUS>
void qknorm_allreduce_fusion_kernel_2stage_launcher(RankData* _dp,
                                                    RankSignals sg,
                                                    Signal* self_sg,
                                                    int rank,
                                                    T* qkv_in,
                                                    T* q_w,
                                                    T* k_w,
                                                    T* q_out,
                                                    T* k_out,
                                                    T* v_out,
                                                    int token_num,
                                                    int hidden_dim_q,
                                                    int hidden_dim_k,
                                                    int hidden_dim_v,
                                                    float eps,
                                                    hipStream_t stream)
{
    constexpr int kPack      = 16 / sizeof(T);
    constexpr int kWarp      = 32;
    constexpr int kWarpElems = kWarp * kPack;

    if(token_num > kMaxBlocks)
    {
        throw std::runtime_error(
            "Token number is too large for qknorm_allreduce_fusion_kernel_2stage kernel");
    }

    // Each section must cover whole 32-thread groups.
    const int section_dims[] = {hidden_dim_q, hidden_dim_k, hidden_dim_v};
    for(const int dim : section_dims)
    {
        if(dim % kWarpElems != 0)
        {
            throw std::runtime_error(
                "Invalid qk hidden dim layout for qknorm_allreduce_fusion_kernel_2stage kernel");
        }
    }

    const int row_elems = hidden_dim_q + hidden_dim_k + hidden_dim_v;
    const dim3 grid(token_num);
    const dim3 block(row_elems / kPack);
    qknorm_allreduce_fusion_kernel_2stage<T, NGPUS, kWarp><<<grid, block, 0, stream>>>(
        _dp, sg, self_sg, rank, qkv_in, q_w, k_w, q_out, k_out, v_out, token_num, hidden_dim_q,
        hidden_dim_k, hidden_dim_v, eps);
}

// Fused single-kernel ("2-stage") all-reduce + residual add + epilogue.
//
// Stage 1 (between start_sync and end_sync): the block's threads are split
// into `ngpus` groups of `tnum_gpu` threads. Group g loads this rank's slice
// of a token from peer g's registered input (`_dp->ptrs[g]`); the loads are
// staged in dynamic shared memory, group 0 accumulates them in fp32, and then
// every group g stores the reduced pack into peer g's temporary buffer
// (`get_tmp_buf(sg.signals[g])`).
//
// Stage 2 (after end_sync): each block walks whole tokens, reads the reduced
// values from this rank's temporary buffer, adds `residual_inp`, writes the
// sum to `residual_out`, and passes an fp32 copy to ar_fusion_epilogue together
// with `weight`, `eps`, `output` and `scale_out`.
//
// Launch requirements implied by the indexing below: blockDim.x ==
// hidden_dim / pack_size, that thread count divisible by ngpus, `size` a
// multiple of hidden_dim, and blockDim.x * sizeof(P) bytes of dynamic shared
// memory.
template <typename T, typename OutT, int ngpus>
__global__ void __launch_bounds__(1024, 1)
    allreduce_fusion_kernel_2stage(RankData* _dp,
                                   RankSignals sg,
                                   Signal* self_sg,
                                   int rank,
                                   T* __restrict__ residual_inp,
                                   T* __restrict__ residual_out,
                                   OutT* __restrict__ output,
                                   T* __restrict__ weight,
                                   float* __restrict__ scale_out,
                                   int size,
                                   int hidden_dim,
                                   float eps)
{
    constexpr int pack_size = 16 / sizeof(T); // elements per 16-byte access
    int block_size          = hidden_dim / pack_size;
    int tnum_gpu            = block_size / ngpus; // threads per peer group
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    extern __shared__ char smem_buf[];
    P* tmp_smem = reinterpret_cast<P*>(smem_buf);
    // "warp_id" selects the peer this thread talks to; it is a group of
    // tnum_gpu threads, not a hardware warp.
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
        tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    A acc;
    start_sync<ngpus>(sg, self_sg, rank);

    // Stage 1: reduce this rank's 1/ngpus slice of each token and scatter the
    // result to every peer's temporary buffer.
    for(int idx = ((blockIdx.x * ngpus + rank) * tnum_gpu + lane_id) * pack_size; idx < size;
        idx += gridDim.x * ngpus * tnum_gpu * pack_size)
    {
        P vec                 = ptrs[warp_id][idx / pack_size];
        tmp_smem[threadIdx.x] = vec;
        __syncthreads();
        if(warp_id == 0)
        {
            // Group 0 already holds peer 0's pack in `vec`; add the others.
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
            {
                acc[v] = upcast_s(vec[v]);
            }
#pragma unroll
            for(int r = 1; r < ngpus; ++r)
            {
                vec = tmp_smem[r * tnum_gpu + lane_id];
#pragma unroll
                for(int v = 0; v < pack_size; ++v)
                {
                    acc[v] += upcast_s(vec[v]);
                }
            }
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
            {
                vec[v] = downcast_s<T>(acc[v]);
            }
            tmp_smem[lane_id] = vec;
        }
        __syncthreads();
        vec                            = tmp_smem[lane_id];
        tmps[warp_id][idx / pack_size] = vec;
        // Group 0 overwrites tmp_smem[lane_id] at the top of the next
        // iteration. Without this barrier a fast group-0 thread could clobber
        // the reduced value before a slower group has read it above.
        __syncthreads();
    }

    int access_id_in_token = threadIdx.x * pack_size;
    P weight_p             = *reinterpret_cast<P*>(weight + access_id_in_token);
    end_sync<ngpus>(sg, self_sg, rank);
    // Stage 2: one token per block per iteration; tidx is the token index.
    for(int idx = blockIdx.x * hidden_dim + access_id_in_token, tidx = blockIdx.x; idx < size;
        idx += gridDim.x * hidden_dim, tidx += gridDim.x)
    {
        P vec = tmps[rank][idx / pack_size];
        P res = *reinterpret_cast<P*>(residual_inp + idx);
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            vec[v] += res[v];
        }
        *reinterpret_cast<P*>(residual_out + idx) = vec;
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
        {
            acc[v] = upcast_s(vec[v]);
        }
        ar_fusion_epilogue<P, A, T, OutT, pack_size>(
            acc, weight_p, hidden_dim, eps, idx, tidx, block_size, output, scale_out);
    }
}

// Host launcher for allreduce_fusion_kernel_2stage.
//
// Launches one block per token (capped at kMaxBlocks) with
// hidden_dim / PACK_SIZE threads and one 16-byte pack of dynamic shared memory
// per thread.
//
// A non-positive `size` is a no-op. Throws std::runtime_error when the shape
// cannot be handled by the kernel:
//  - hidden_dim must be a positive multiple of PACK_SIZE * NGPUS, otherwise
//    the kernel's per-peer thread groups do not tile the block;
//  - hidden_dim / PACK_SIZE must not exceed the kernel's 1024-thread bound;
//  - size must be a multiple of hidden_dim, otherwise threads of one block
//    would disagree on loop trip counts around __syncthreads().
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_2stage_launcher(RankData* _dp,
                                             RankSignals sg,
                                             Signal* self_sg,
                                             int rank,
                                             T* residual_inp,
                                             T* residual_out,
                                             OutT* output,
                                             T* weight,
                                             float* scale_out,
                                             int size,
                                             int hidden_dim,
                                             float eps,
                                             hipStream_t stream)
{
    constexpr int PACK_SIZE         = 16 / sizeof(T);
    constexpr int kMaxThreads       = 1024; // matches __launch_bounds__ of the kernel
    constexpr int kHiddenDimQuantum = PACK_SIZE * NGPUS;

    if(size <= 0)
    {
        return; // nothing to reduce; a 0-block launch would be invalid
    }
    if(hidden_dim <= 0 || hidden_dim % kHiddenDimQuantum != 0)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage: hidden_dim=" +
                                 std::to_string(hidden_dim) + " must be a positive multiple of " +
                                 std::to_string(kHiddenDimQuantum));
    }
    const int BLOCK_SIZE = hidden_dim / PACK_SIZE;
    if(BLOCK_SIZE > kMaxThreads)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage: hidden_dim=" +
                                 std::to_string(hidden_dim) + " needs " +
                                 std::to_string(BLOCK_SIZE) + " threads per block, limit is " +
                                 std::to_string(kMaxThreads));
    }
    if(size % hidden_dim != 0)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage: size=" + std::to_string(size) +
                                 " is not a multiple of hidden_dim=" +
                                 std::to_string(hidden_dim));
    }

    // The kernel loops over tokens, so the grid only needs to be capped.
    const int token_num = std::min(size / hidden_dim, kMaxBlocks);
    dim3 threadsPerBlock(BLOCK_SIZE);
    dim3 numBlocks(token_num);
    size_t smem_size = BLOCK_SIZE * sizeof(typename opus::vector_t<T, PACK_SIZE>);
    allreduce_fusion_kernel_2stage<T, OutT, NGPUS>
        <<<numBlocks, threadsPerBlock, smem_size, stream>>>(_dp,
                                                            sg,
                                                            self_sg,
                                                            rank,
                                                            residual_inp,
                                                            residual_out,
                                                            output,
                                                            weight,
                                                            scale_out,
                                                            size,
                                                            hidden_dim,
                                                            eps);
}

// Per-group quant variant of the 2-stage kernel.
//
// Stage 1 (between start_sync and end_sync) is the same reduce + scatter as
// the per-token kernel: the block is split into `ngpus` groups of `tnum_gpu`
// threads, group g loads from peer g's input, group 0 sums in fp32 through
// dynamic shared memory, and every group g stores the reduced pack into peer
// g's temporary buffer.
//
// Stage 2 (after end_sync) adds `residual_inp`, writes `residual_out`, and
// calls ar_fusion_epilogue_per_group with `group_size` and the optional
// `bf16_output` pointer (nullptr by default).
//
// Launch requirements implied by the indexing: blockDim.x ==
// hidden_dim / pack_size, divisible by ngpus, `size` a multiple of hidden_dim,
// and blockDim.x * sizeof(P) bytes of dynamic shared memory.
template <typename T, typename OutT, int ngpus>
__global__ void __launch_bounds__(1024, 1)
    allreduce_fusion_kernel_2stage_per_group(RankData* _dp,
                                             RankSignals sg,
                                             Signal* self_sg,
                                             int rank,
                                             T* __restrict__ residual_inp,
                                             T* __restrict__ residual_out,
                                             OutT* __restrict__ output,
                                             T* __restrict__ weight,
                                             float* __restrict__ scale_out,
                                             int size,
                                             int hidden_dim,
                                             int group_size,
                                             float eps,
                                             T* __restrict__ bf16_output = nullptr)
{
    constexpr int pack_size = 16 / sizeof(T); // elements per 16-byte access
    int block_size          = hidden_dim / pack_size;
    int tnum_gpu            = block_size / ngpus; // threads per peer group
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;
    extern __shared__ char smem_buf[];
    P* tmp_smem = reinterpret_cast<P*>(smem_buf);
    // "warp_id" is the peer group index (tnum_gpu threads), not a hardware warp.
    int warp_id = threadIdx.x / tnum_gpu;
    int lane_id = threadIdx.x % tnum_gpu;
    const P* ptrs[ngpus];
    P* tmps[ngpus];
#pragma unroll
    for(int i = 0; i < ngpus; ++i)
    {
        ptrs[i] = (const P*)_dp->ptrs[i];
        tmps[i] = get_tmp_buf<P>(sg.signals[i]);
    }
    A acc;
    start_sync<ngpus>(sg, self_sg, rank);

    // Stage 1: reduce this rank's slice of each token, scatter to all peers.
    for(int idx = ((blockIdx.x * ngpus + rank) * tnum_gpu + lane_id) * pack_size; idx < size;
        idx += gridDim.x * ngpus * tnum_gpu * pack_size)
    {
        P vec                 = ptrs[warp_id][idx / pack_size];
        tmp_smem[threadIdx.x] = vec;
        __syncthreads();
        if(warp_id == 0)
        {
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
                acc[v] = upcast_s(vec[v]);
#pragma unroll
            for(int r = 1; r < ngpus; ++r)
            {
                vec = tmp_smem[r * tnum_gpu + lane_id];
#pragma unroll
                for(int v = 0; v < pack_size; ++v)
                    acc[v] += upcast_s(vec[v]);
            }
#pragma unroll
            for(int v = 0; v < pack_size; ++v)
                vec[v] = downcast_s<T>(acc[v]);
            tmp_smem[lane_id] = vec;
        }
        __syncthreads();
        vec                            = tmp_smem[lane_id];
        tmps[warp_id][idx / pack_size] = vec;
        // Group 0 overwrites tmp_smem[lane_id] at the top of the next
        // iteration. Without this barrier a fast group-0 thread could clobber
        // the reduced value before a slower group has read it above.
        __syncthreads();
    }

    int access_id_in_token = threadIdx.x * pack_size;
    P weight_p             = *reinterpret_cast<P*>(weight + access_id_in_token);
    end_sync<ngpus>(sg, self_sg, rank);
    // Stage 2: one token per block per iteration; tidx is the token index.
    for(int idx = blockIdx.x * hidden_dim + access_id_in_token, tidx = blockIdx.x; idx < size;
        idx += gridDim.x * hidden_dim, tidx += gridDim.x)
    {
        P vec = tmps[rank][idx / pack_size];
        P res = *reinterpret_cast<P*>(residual_inp + idx);
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            vec[v] += res[v];
        *reinterpret_cast<P*>(residual_out + idx) = vec;
#pragma unroll
        for(int v = 0; v < pack_size; ++v)
            acc[v] = upcast_s(vec[v]);
        ar_fusion_epilogue_per_group<P, A, T, OutT, pack_size>(
            acc, weight_p, hidden_dim, eps, idx, tidx, block_size,
            group_size, output, scale_out, /*active=*/true, bf16_output);
    }
}

// Host launcher for allreduce_fusion_kernel_2stage_per_group.
//
// Launches one block per token (capped at kMaxBlocks) with
// hidden_dim / PACK_SIZE threads and one 16-byte pack of dynamic shared memory
// per thread. `bf16_output` is forwarded unchanged and may be nullptr.
//
// A non-positive `size` is a no-op. Throws std::runtime_error when:
//  - hidden_dim is not a positive multiple of PACK_SIZE * NGPUS (the kernel's
//    per-peer thread groups would not tile the block);
//  - hidden_dim / PACK_SIZE exceeds the kernel's 1024-thread launch bound;
//  - size is not a multiple of hidden_dim (threads of one block would disagree
//    on loop trip counts around __syncthreads()).
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_2stage_per_group_launcher(
    RankData* _dp, RankSignals sg, Signal* self_sg, int rank,
    T* residual_inp, T* residual_out, OutT* output, T* weight,
    float* scale_out, int size, int hidden_dim, int group_size,
    float eps, hipStream_t stream, T* bf16_output = nullptr)
{
    constexpr int PACK_SIZE         = 16 / sizeof(T);
    constexpr int kMaxThreads       = 1024; // matches __launch_bounds__ of the kernel
    constexpr int kHiddenDimQuantum = PACK_SIZE * NGPUS;

    if(size <= 0)
    {
        return; // nothing to reduce; a 0-block launch would be invalid
    }
    if(hidden_dim <= 0 || hidden_dim % kHiddenDimQuantum != 0)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage_per_group: hidden_dim=" +
                                 std::to_string(hidden_dim) + " must be a positive multiple of " +
                                 std::to_string(kHiddenDimQuantum));
    }
    const int BLOCK_SIZE = hidden_dim / PACK_SIZE;
    if(BLOCK_SIZE > kMaxThreads)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage_per_group: hidden_dim=" +
                                 std::to_string(hidden_dim) + " needs " +
                                 std::to_string(BLOCK_SIZE) + " threads per block, limit is " +
                                 std::to_string(kMaxThreads));
    }
    if(size % hidden_dim != 0)
    {
        throw std::runtime_error("allreduce_fusion_kernel_2stage_per_group: size=" +
                                 std::to_string(size) + " is not a multiple of hidden_dim=" +
                                 std::to_string(hidden_dim));
    }

    // The kernel loops over tokens, so the grid only needs to be capped.
    const int token_num = std::min(size / hidden_dim, kMaxBlocks);
    dim3 threadsPerBlock(BLOCK_SIZE);
    dim3 numBlocks(token_num);
    size_t smem_size = BLOCK_SIZE * sizeof(typename opus::vector_t<T, PACK_SIZE>);
    allreduce_fusion_kernel_2stage_per_group<T, OutT, NGPUS>
        <<<numBlocks, threadsPerBlock, smem_size, stream>>>(
            _dp, sg, self_sg, rank,
            residual_inp, residual_out, output, weight, scale_out,
            size, hidden_dim, group_size, eps, bf16_output);
}

// Second half of the split all-reduce path: one block per token, one 16-byte
// pack per thread. Reads the already-reduced values from this rank's temporary
// buffer, adds the residual, writes `residual_out`, and feeds the fp32 copy to
// ar_fusion_epilogue. There is no bounds check against `size`; the grid is
// expected to contain exactly size / hidden_dim blocks.
template <typename T, typename OutT>
__global__ void __launch_bounds__(1024, 1)
    local_device_load_rmsnorm_quant_naive(RankSignals sg,
                                          int rank,
                                          T* __restrict__ residual_inp,
                                          T* __restrict__ residual_out,
                                          OutT* __restrict__ output,
                                          T* __restrict__ weight,
                                          float* __restrict__ scale_out,
                                          int size,
                                          int hidden_dim,
                                          float eps)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;

    int packs_per_token = hidden_dim / pack_size;
    int token           = blockIdx.x;
    int col             = threadIdx.x * pack_size;       // element offset inside the token
    int elem            = token * hidden_dim + col;      // flat element offset

    P* reduced = get_tmp_buf<P>(sg.signals[rank]);
    P gamma    = *reinterpret_cast<P*>(weight + col);

    // reduced value + residual, kept in T precision and written back
    P sum      = reduced[elem / pack_size];
    P residual = *reinterpret_cast<P*>(residual_inp + elem);
#pragma unroll
    for(int k = 0; k < pack_size; ++k)
    {
        sum[k] += residual[k];
    }
    *reinterpret_cast<P*>(residual_out + elem) = sum;

    // widen to fp32 for the epilogue
    A acc;
#pragma unroll
    for(int k = 0; k < pack_size; ++k)
    {
        acc[k] = upcast_s(sum[k]);
    }
    ar_fusion_epilogue<P, A, T, OutT, pack_size>(
        acc, gamma, hidden_dim, eps, elem, token, packs_per_token, output, scale_out);
}

// Per-group quant variant of the naive local device load kernel.
// One block per token, one 16-byte pack per thread: loads the reduced values
// from this rank's temporary buffer, adds the residual, stores `residual_out`,
// then calls ar_fusion_epilogue_per_group with `group_size` and the optional
// `bf16_output` pointer. There is no bounds check against `size`; the grid is
// expected to contain exactly size / hidden_dim blocks.
template <typename T, typename OutT>
__global__ void __launch_bounds__(1024, 1)
    local_device_load_rmsnorm_quant_per_group_naive(RankSignals sg,
                                                    int rank,
                                                    T* __restrict__ residual_inp,
                                                    T* __restrict__ residual_out,
                                                    OutT* __restrict__ output,
                                                    T* __restrict__ weight,
                                                    float* __restrict__ scale_out,
                                                    int size,
                                                    int hidden_dim,
                                                    int group_size,
                                                    float eps,
                                                    T* __restrict__ bf16_output = nullptr)
{
    constexpr int pack_size = 16 / sizeof(T);
    using P                 = typename opus::vector_t<T, pack_size>;
    using A                 = typename opus::vector_t<opus::fp32_t, pack_size>;

    int packs_per_token = hidden_dim / pack_size;
    int token           = blockIdx.x;
    int col             = threadIdx.x * pack_size;  // element offset inside the token
    int elem            = token * hidden_dim + col; // flat element offset

    P* reduced = get_tmp_buf<P>(sg.signals[rank]);
    P gamma    = *reinterpret_cast<P*>(weight + col);

    // reduced value + residual, kept in T precision and written back
    P sum      = reduced[elem / pack_size];
    P residual = *reinterpret_cast<P*>(residual_inp + elem);
#pragma unroll
    for(int k = 0; k < pack_size; ++k)
    {
        sum[k] += residual[k];
    }
    *reinterpret_cast<P*>(residual_out + elem) = sum;

    // widen to fp32 for the epilogue
    A acc;
#pragma unroll
    for(int k = 0; k < pack_size; ++k)
    {
        acc[k] = upcast_s(sum[k]);
    }
    ar_fusion_epilogue_per_group<P, A, T, OutT, pack_size>(
        acc, gamma, hidden_dim, eps, elem, token, packs_per_token,
        group_size, output, scale_out, /*active=*/true, bf16_output);
}

// Two-kernel ("split") fused all-reduce with per-group quantisation.
//
// Step 1 launches reduce_scatter_cross_device_store with 512-thread blocks and
// at most kMaxBlocks blocks (the per-block slots in Signal are sized by
// kMaxBlocks). Step 2 launches local_device_load_rmsnorm_quant_per_group_naive
// with one block per token and hidden_dim / PACK_SIZE threads.
//
// A non-positive `size` is a no-op. Throws std::runtime_error for an
// unsupported NGPUS (only 2, 4 and 8 are dispatched) and for a hidden_dim that
// is not a positive multiple of PACK_SIZE or that would need more than the
// 1024 threads per block the step-2 kernel is bounded to.
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_split_per_group_launcher(RankData* _dp,
                                                      RankSignals sg,
                                                      Signal* self_sg,
                                                      int rank,
                                                      T* residual_inp,
                                                      T* residual_out,
                                                      OutT* output,
                                                      T* weight,
                                                      float* scale_out,
                                                      int size,
                                                      int hidden_dim,
                                                      int group_size,
                                                      float eps,
                                                      hipStream_t stream,
                                                      T* bf16_output = nullptr)
{
    constexpr int PACK_SIZE   = 16 / sizeof(T);
    constexpr int kMaxThreads = 1024; // matches __launch_bounds__ of the step-2 kernel

    if(size <= 0)
    {
        return; // nothing to reduce; 0-block launches would be invalid
    }
    // Validate before launching anything so a bad shape never runs only half
    // of the two-step sequence.
    if(hidden_dim <= 0 || hidden_dim % PACK_SIZE != 0 || hidden_dim / PACK_SIZE > kMaxThreads)
    {
        throw std::runtime_error("unsupported hidden_dim=" + std::to_string(hidden_dim));
    }

    // step 1: reduce-scatter + allgather cross device store (same as per-token)
    dim3 block(512);
    int block_num = ((size / NGPUS) + 512 - 1) / 512;
    dim3 grid(std::min(block_num, kMaxBlocks));
    switch(NGPUS)
    {
    case 8:
        reduce_scatter_cross_device_store<T, 8>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    case 4:
        reduce_scatter_cross_device_store<T, 4>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    case 2:
        reduce_scatter_cross_device_store<T, 2>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    default:
        throw std::runtime_error("unsupported NGPUS=" + std::to_string(NGPUS));
    }
    // step 2: local device load + rmsnorm + per-group quant (+ optional bf16 mirror)
    int BLOCK_SIZE = hidden_dim / PACK_SIZE;
    int nblocks    = size / hidden_dim;
    if(nblocks == 0)
    {
        return; // less than one full token: no step-2 work to launch
    }
    dim3 threadsPerBlock(BLOCK_SIZE);
    dim3 numBlocks(nblocks);
    local_device_load_rmsnorm_quant_per_group_naive<T, OutT>
        <<<numBlocks, threadsPerBlock, 0, stream>>>(
            sg, rank, residual_inp, residual_out, output, weight, scale_out,
            size, hidden_dim, group_size, eps, bf16_output);
}

// Two-kernel ("split") fused all-reduce + rmsnorm + quant.
//
// Step 1 launches reduce_scatter_cross_device_store with 512-thread blocks and
// at most kMaxBlocks blocks (the per-block slots in Signal are sized by
// kMaxBlocks). Step 2 launches local_device_load_rmsnorm_quant_naive with one
// block per token and hidden_dim / PACK_SIZE threads.
//
// A non-positive `size` is a no-op. Throws std::runtime_error for an
// unsupported NGPUS (only 2, 4 and 8 are dispatched) and for a hidden_dim that
// is not a positive multiple of PACK_SIZE or that would need more than the
// 1024 threads per block the step-2 kernel is bounded to.
template <typename T, typename OutT, int NGPUS>
void allreduce_fusion_kernel_split_launcher(RankData* _dp,
                                            RankSignals sg,
                                            Signal* self_sg,
                                            int rank,
                                            T* residual_inp,
                                            T* residual_out,
                                            OutT* output,
                                            T* weight,
                                            float* scale_out,
                                            int size,
                                            int hidden_dim,
                                            float eps,
                                            hipStream_t stream)
{
    constexpr int PACK_SIZE   = 16 / sizeof(T);
    constexpr int kMaxThreads = 1024; // matches __launch_bounds__ of the step-2 kernel

    if(size <= 0)
    {
        return; // nothing to reduce; 0-block launches would be invalid
    }
    // Validate before launching anything so a bad shape never runs only half
    // of the two-step sequence.
    if(hidden_dim <= 0 || hidden_dim % PACK_SIZE != 0 || hidden_dim / PACK_SIZE > kMaxThreads)
    {
        throw std::runtime_error("fused allreduce rmsnorm: unsupported hidden_dim=" +
                                 std::to_string(hidden_dim));
    }

    // step 1, run reduce-scatter + allgather cross device save
    dim3 block(512);
    int block_num = ((size / NGPUS) + 512 - 1) / 512;
    dim3 grid(std::min(block_num, kMaxBlocks));
    switch(NGPUS)
    {
    case 8:
        reduce_scatter_cross_device_store<T, 8>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    case 4:
        reduce_scatter_cross_device_store<T, 4>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    case 2:
        reduce_scatter_cross_device_store<T, 2>
            <<<grid, block, 0, stream>>>(_dp, sg, self_sg, rank, size);
        break;
    default: throw std::runtime_error("fused allreduce rmsnorm: unsupported NGPUS=" + std::to_string(NGPUS));
    }
    // step 2, run allgather local device load + rmsnorm + quant
    int BLOCK_SIZE = hidden_dim / PACK_SIZE;
    int nblocks    = size / hidden_dim;
    if(nblocks == 0)
    {
        return; // less than one full token: no step-2 work to launch
    }
    dim3 threadsPerBlock(BLOCK_SIZE);
    dim3 numBlocks(nblocks);
    local_device_load_rmsnorm_quant_naive<T, OutT>
        <<<numBlocks, threadsPerBlock, 0, stream>>>(
            sg, rank, residual_inp, residual_out, output, weight, scale_out, size, hidden_dim, eps);
}

// Byte-array image of a hipIpcMemHandle_t, used as the key type of the
// std::map that caches opened IPC handles. The static_asserts pin the size and
// alignment to those of hipIpcMemHandle_t, which the pointer casts between the
// two types in open_ipc_handle() depend on.
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));

class CustomAllreduce
{
    public:
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2436
2437
2438
2439
2440
2441
    int rank_;
    int world_size_;
    bool full_nvlink_;

    // below are device pointers
    RankSignals sg_;
2442
2443
2444
    std::unordered_map<void*, RankData*> input_buffer;
    std::unordered_map<void*, RankData*> output_buffers_;
    Signal* self_sg_;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2445
2446
2447

    // stores the registered device pointers from all ranks
    RankData *d_rank_data_base_, *d_rank_data_end_;
2448
2449
    std::vector<void*> graph_unreg_input_buffers_;
    std::vector<void*> graph_unreg_output_buffers_;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2450
    // a map from IPC handles to opened IPC pointers
2451
    std::map<IPC_KEY, char*> ipc_handles_;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2452
2453

#ifdef DTK_ENV
2454
2455
    // DTK (Hygon DCU) memory-ordering helpers: an event to sequence the
    // pre-allreduce D2H flush, and a small pinned host buffer it reads into.
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2456
    hipEvent_t event_;
2457
    void*  buffer_ptr_;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
2469
    size_t buffer_size_;
#endif

    /**
     * meta is a pointer to device metadata and temporary buffer for allreduce.
     *
     * There's a total of sizeof(Signal) of prefix before the actual data,
     * so meta + 1 points to actual temporary buffer.
     *
     * note: this class does not own any device memory. Any required buffers
     * are passed in from the constructor
     */
2470
2471
2472
2473
2474
2475
    CustomAllreduce(Signal* meta,
                    void* rank_data,
                    size_t rank_data_sz,
                    const hipIpcMemHandle_t* handles,
                    const std::vector<int64_t>& offsets,
                    int rank,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2476
2477
2478
2479
2480
                    bool fully_connected = true)
        : rank_(rank),
          world_size_(offsets.size()),
          full_nvlink_(fully_connected),
          self_sg_(meta),
2481
          d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2482
2483
          d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData))
    {
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
        for(int i = 0; i < world_size_; i++)
        {
            Signal* rank_sg;
            if(i != rank_)
            {
                char* handle = open_ipc_handle(&handles[i]);
                handle += offsets[i];
                rank_sg = (Signal*)handle;
            }
            else
            {
                rank_sg = self_sg_;
            }
            sg_.signals[i] = rank_sg;
        }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2499
#ifdef DTK_ENV
2500
2501
2502
        hipEventCreateWithFlags(&event_, hipEventReleaseToSystem | hipEventDisableTiming);
        buffer_size_ = 4;
        hipHostMalloc(&buffer_ptr_, buffer_size_, hipHostMallocDefault);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2503
#endif
2504
    }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2505

2506
2507
2508
2509
    char* open_ipc_handle(const void* ipc_handle)
    {
        auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
        if(new_handle)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2510
        {
2511
2512
2513
2514
2515
            char* ipc_ptr;
            HIP_CALL(hipIpcOpenMemHandle((void**)&ipc_ptr,
                                         *((const hipIpcMemHandle_t*)ipc_handle),
                                         hipIpcMemLazyEnablePeerAccess));
            it->second = ipc_ptr;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2516
        }
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
        return it->second;
    }

    std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta()
    {
        auto num_input_buffers  = graph_unreg_input_buffers_.size();
        auto num_output_buffers = graph_unreg_output_buffers_.size();
        auto num_buffers        = num_input_buffers + num_output_buffers;
        auto handle_sz          = sizeof(hipIpcMemHandle_t);
        std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
        std::vector<int64_t> offsets(num_buffers);
        for(int i = 0; i < num_input_buffers; i++)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2529
        {
2530
2531
2532
2533
2534
            auto ptr = graph_unreg_input_buffers_[i];
            void* base_ptr;
            // note: must share the base address of each allocation, or we get wrong
            // address
            if(hipPointerGetAttribute(&base_ptr,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2535
#ifdef USE_ROCM
2536
                                      HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2537
#else
2538
                                      CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2539
#endif
2540
2541
2542
2543
2544
2545
2546
2547
2548
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
                                      (hipDeviceptr_t)ptr) != CUDA_SUCCESS)
                throw std::runtime_error("failed to get pointer attr");
            HIP_CALL(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
            offsets[i] = ((char*)ptr) - ((char*)base_ptr);
        }

        // Process output buffers
        for(int i = 0; i < num_output_buffers; i++)
        {
            auto ptr = graph_unreg_output_buffers_[i];
            void* base_ptr;
            if(hipPointerGetAttribute(&base_ptr,
#ifdef USE_ROCM
                                      HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
                                      CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
                                      (hipDeviceptr_t)ptr) != CUDA_SUCCESS)
                throw std::runtime_error("failed to get pointer attr for output");
            HIP_CALL(hipIpcGetMemHandle(
                (hipIpcMemHandle_t*)&handles[(num_input_buffers + i) * handle_sz], base_ptr));
            offsets[num_input_buffers + i] = ((char*)ptr) - ((char*)base_ptr);
        }

        return std::make_pair(handles, offsets);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2565
2566
2567
2568
    }

    void check_rank_data_capacity(size_t num = 1)
    {
2569
2570
2571
        if(d_rank_data_base_ + num > d_rank_data_end_)
            throw std::runtime_error("Rank data buffer is overflowed by " +
                                     std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2572
2573
    }

2574
2575
2576
    void register_input_buffer(const hipIpcMemHandle_t* ipc_handles,
                               const int64_t* offsets,
                               void* self)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2577
    {
2578
2579
2580
        check_rank_data_capacity();
        RankData data;
        for(int i = 0; i < world_size_; i++)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2581
        {
2582
2583
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
            if(i != rank_)
            {
                char* handle = open_ipc_handle((void*)&ipc_handles[i]);
                handle += offsets[i];
                data.ptrs[i] = handle;
            }
            else
            {
                data.ptrs[i] = self;
            }
        }
        auto d_data = d_rank_data_base_++;
        HIP_CALL(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
        input_buffer[self] = d_data;
    }

    void register_output_buffer(const hipIpcMemHandle_t* ipc_handles,
                                const int64_t* offsets,
                                void* self)
    {
        check_rank_data_capacity();
        RankData data;
        for(int i = 0; i < world_size_; i++)
        {
            if(i != rank_)
            {
                char* handle = open_ipc_handle((void*)&ipc_handles[i]);
                handle += offsets[i];
                data.ptrs[i] = handle;
            }
            else
            {
                data.ptrs[i] = self;
            }
        }
        auto d_data = d_rank_data_base_++;
        HIP_CALL(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
        output_buffers_[self] = d_data;
    }

    // Returns the device RankData entry for input buffer `input`.
    //  - Registered buffer: the stored entry.
    //  - Unregistered while `stream` is being captured: the buffer is queued in
    //    graph_unreg_input_buffers_ and the returned pointer is the slot that
    //    register_graph_buffers() fills later (d_rank_data_base_ + queue index).
    //  - Otherwise: throws std::runtime_error.
    // NOTE(review): output slots handed out by get_output_buffer_RD() are
    // computed from the input count at that moment, so an input queued after an
    // output in the same capture receives an index an output may already hold.
    // Confirm callers never interleave the two in that order.
    RankData* get_buffer_RD(hipStream_t stream, void* input)
    {
        if(auto it = input_buffer.find(input); it != input_buffer.end())
        {
            return it->second;
        }

        hipStreamCaptureStatus status;
        HIP_CALL(hipStreamIsCapturing(stream, &status));
        if(status != hipStreamCaptureStatusActive)
        {
            throw std::runtime_error("buffer address " +
                                     std::to_string(reinterpret_cast<uint64_t>(input)) +
                                     " is not registered!");
        }

        RankData* ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size();
        graph_unreg_input_buffers_.push_back(input);
        return ptrs;
    }

    // Returns the device RankData entry for output buffer `output`.
    //  - Registered buffer: the stored entry.
    //  - Unregistered while `stream` is being captured: the buffer is queued in
    //    graph_unreg_output_buffers_ and a future slot is returned.
    //  - Otherwise: throws std::runtime_error.
    // NOTE(review): the slot returned during capture is d_rank_data_base_ +
    // (inputs queued so far) + (outputs queued so far), while
    // register_graph_buffers() stores output i at d_rank_data_base_ +
    // (final input count) + i. The two agree only if no input is queued after
    // the first output of the same capture. Confirm callers guarantee this.
    RankData* get_output_buffer_RD(hipStream_t stream, void* output)
    {
        if(auto it = output_buffers_.find(output); it != output_buffers_.end())
        {
            return it->second;
        }

        hipStreamCaptureStatus status;
        HIP_CALL(hipStreamIsCapturing(stream, &status));
        if(status != hipStreamCaptureStatusActive)
        {
            throw std::runtime_error("output buffer address " +
                                     std::to_string(reinterpret_cast<uint64_t>(output)) +
                                     " is not registered!");
        }

        // For graph mode, collect output addresses
        RankData* ptrs = d_rank_data_base_ + graph_unreg_input_buffers_.size() +
                         graph_unreg_output_buffers_.size();
        graph_unreg_output_buffers_.push_back(output);
        return ptrs;
    }

    // note: when registering graph buffers, we intentionally choose to not
    // deduplicate the addresses. That means if the allocator reuses some
    // addresses, they will be registered again. This is to account for the remote
    // possibility of different allocation patterns between ranks. For example,
    // rank 1 may get the same input address for the second allreduce, but rank 2
    // got a different address. IPC handles have internal reference counting
    // mechanism so overhead should be small.
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
    void register_graph_buffers(const void* const* handles_per_rank,
                                const int64_t* const* offsets_per_rank)
    {
        auto num_input_buffers  = graph_unreg_input_buffers_.size();
        auto num_output_buffers = graph_unreg_output_buffers_.size();
        auto total_buffers      = num_input_buffers + num_output_buffers;
        check_rank_data_capacity(total_buffers);
        std::vector<RankData> rank_data(total_buffers);

        // Register input buffers
        for(int i = 0; i < num_input_buffers; i++)
        {
            auto self_ptr = graph_unreg_input_buffers_[i];
            auto& rd      = rank_data[i];
            for(int j = 0; j < world_size_; j++)
            {
                if(j != rank_)
                {
                    auto* ipc_handle_ptr =
                        (const hipIpcMemHandle_t*)handles_per_rank[j] + i;
                    char* handle = open_ipc_handle(ipc_handle_ptr);
                    handle += offsets_per_rank[j][i];
                    rd.ptrs[j] = handle;
                }
                else
                {
                    rd.ptrs[j] = self_ptr;
                }
            }
        }
        // Register output buffers
        for(int i = 0; i < num_output_buffers; i++)
        {
            auto self_ptr = graph_unreg_output_buffers_[i];
            auto& rd      = rank_data[num_input_buffers + i];
            for(int j = 0; j < world_size_; j++)
            {
                if(j != rank_)
                {
                    auto* ipc_handle_ptr =
                        (const hipIpcMemHandle_t*)handles_per_rank[j] + num_input_buffers + i;
                    char* handle = open_ipc_handle(ipc_handle_ptr);
                    handle += offsets_per_rank[j][num_input_buffers + i];
                    rd.ptrs[j] = handle;
                }
                else
                {
                    rd.ptrs[j] = self_ptr;
                }
            }
            output_buffers_[self_ptr] = d_rank_data_base_ + num_input_buffers + i;
        }

        HIP_CALL(hipMemcpy(d_rank_data_base_,
                           rank_data.data(),
                           sizeof(RankData) * total_buffers,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2743
                           hipMemcpyHostToDevice));
2744
2745
2746
        d_rank_data_base_ += total_buffers;
        graph_unreg_input_buffers_.clear();
        graph_unreg_output_buffers_.clear();
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2747
2748
2749
2750
2751
2752
2753
2754
2755
2756
2757
2758
    }

    /*
     * Launch the fp8-quantized all-reduce kernel (allReduceQuantFp8).
     *
     * stream : HIP stream the kernel is enqueued on.
     * input  : local input buffer; its per-rank pointer table is looked up with
     *          get_buffer_RD().
     * output : destination buffer passed straight to the kernel.
     * size   : element count of `input`; the kernel receives size / 8 (one unit
     *          per 8-element pack).
     *
     * The quantization group size (128/64/32/16/8) is the largest of those
     * values that divides the per-rank share `size / world_size_`; 8 is used as
     * the fallback when none of the larger ones divides it.
     *
     * Throws std::runtime_error when world_size_ is not one of 2, 4, 6, 8
     * (previously the call returned silently without launching anything).
     *
     * Notes carried over from the original author (not verified here):
     *   - reference case size on a single gpu: (128, 8192)
     *   - open question: should the quant scale match hidden_dim when
     *     hidden_dim is less than 128?
     *
     * NOTE(review): grid.x becomes 0 when the per-rank share is smaller than
     * 8 * 256 elements, which makes the launch invalid - confirm callers never
     * pass such small inputs.
     */
    template <typename T>
    void runFp8QuantKernel(hipStream_t stream, T* input, T* output, int size)
    {
        // Every configuration uses 8-element packs and 256-thread blocks.
        constexpr int kPackSize  = 8;
        constexpr int kBlockSize = 256;

        RankData* ptrs = get_buffer_RD(stream, input);

        const int single_device_size = size / world_size_;

        dim3 grid, block;
        block.x = kBlockSize;
        // At most 16384 threads in flight (64 blocks of 256 threads).
        grid.x = std::min(16384 / kBlockSize, single_device_size / (kPackSize * kBlockSize));
        // The kernel counts in packs, not in elements.
        size /= kPackSize;

// One switch case: launch the kernel instantiated for `ngpus` ranks and return.
#define AITER_FP8_QUANT_CASE(quant_scale, ngpus)                                    \
    case ngpus:                                                                     \
        allReduceQuantFp8<T, quant_scale, kPackSize, ngpus>                         \
            <<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size); \
        return;

// Dispatch on the runtime world size for a fixed quantization group size.
// Falls through (without launching) for an unsupported world size.
#define AITER_FP8_QUANT_DISPATCH(quant_scale) \
    switch(world_size_)                       \
    {                                         \
        AITER_FP8_QUANT_CASE(quant_scale, 2)  \
        AITER_FP8_QUANT_CASE(quant_scale, 4)  \
        AITER_FP8_QUANT_CASE(quant_scale, 6)  \
        AITER_FP8_QUANT_CASE(quant_scale, 8)  \
    default: break;                           \
    }

        if(single_device_size % 128 == 0)
        {
            AITER_FP8_QUANT_DISPATCH(128)
        }
        else if(single_device_size % 64 == 0)
        {
            AITER_FP8_QUANT_DISPATCH(64)
        }
        else if(single_device_size % 32 == 0)
        {
            AITER_FP8_QUANT_DISPATCH(32)
        }
        else if(single_device_size % 16 == 0)
        {
            AITER_FP8_QUANT_DISPATCH(16)
        }
        else // fallback: smallest supported quantization group
        {
            AITER_FP8_QUANT_DISPATCH(8)
        }

#undef AITER_FP8_QUANT_DISPATCH
#undef AITER_FP8_QUANT_CASE

        // Only reached when no case matched, i.e. nothing was launched.
        throw std::runtime_error(
            "fp8 quant allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }

    /**
     * This is the result after careful grid search. Using 36 blocks give the best
     * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
     * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
     * Not quite sure the underlying reason, but my guess is that too many SMs
     * will cause contention on NVLink bus.
     */
    template <typename T>
2821
2822
2823
2824
2825
2826
    void allreduce(hipStream_t stream,
                   T* input,
                   T* output,
                   int size,
                   bool use_new                 = true,
                   bool is_broadcast_reg_outptr = false,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2827
#ifndef USE_ROCM
2828
2829
                   int threads     = 512,
                   int block_limit = 20){
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2830
#else
2831
2832
                   int threads     = 512,
                   int block_limit = 16)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2833
2834
    {
#endif
2835
2836
2837
2838
2839
2840
2841
2842
        auto d = 16 / sizeof(T);
    if(size % d != 0)
        throw std::runtime_error("custom allreduce currently requires input length to be multiple "
                                 "of " +
                                 std::to_string(d));
    if(block_limit > kMaxBlocks)
        throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) +
                                 ". Got " + std::to_string(block_limit));
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2843

2844
2845
2846
    RankData* input_ptrs  = get_buffer_RD(stream, input);
    RankData* output_ptrs = nullptr;
    if(is_broadcast_reg_outptr)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2847
    {
2848
        output_ptrs = get_output_buffer_RD(stream, output);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2849
    }
2850
2851
2852
2853
2854
2855

    auto bytes = size * sizeof(T);
    size /= d;

    // use new version of allreduce kernel
    if(use_new)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2856
    {
2857
2858
2859
2860
2861
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
        hipDevice_t dev;
        hipDeviceProp_t dev_prop;
        hipGetDevice(&dev);
        hipGetDeviceProperties(&dev_prop, dev);
        std::string arch    = dev_prop.gcnArchName;
        bool use_write_mode = false;

        int blocks       = 16;
        bool call_1stage = false;
        bool call_2stage = false;
        if(world_size_ == 2)
        {
            call_1stage = true;
        }
        else if(full_nvlink_)
        {
            if((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024))
            {
                call_1stage = true;
            }
            else
            {
                call_2stage = true;
            }
        }
        if(call_1stage)
        {
            blocks = std::min(kMaxBlocks,
                              (size + (threads / world_size_) - 1) / (threads / world_size_));
        }
        else if(call_2stage)
        {
            blocks = std::min(kMaxBlocks,
                              (size / world_size_ + (threads / world_size_) - 1) /
                                  (threads / world_size_));
            if(world_size_ == 8 && bytes > 512 * 4096 * 2 &&
               arch.find("gfx942") != std::string::npos)
            {
                use_write_mode = true;
            }
        }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2898
2899

#ifdef DTK_ENV
2900
// DTK (Hygon DCU): standard <<<>>> launch works correctly.
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2901
#define KL(ngpus, name)                                                       \
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
    do                                                                        \
    {                                                                         \
        if(is_broadcast_reg_outptr)                                           \
        {                                                                     \
            name<T, ngpus, true><<<blocks, threads, 0, stream>>>(             \
                input_ptrs, output_ptrs, sg_, self_sg_, output, rank_, size); \
        }                                                                     \
        else                                                                  \
        {                                                                     \
            name<T, ngpus, false><<<blocks, threads, 0, stream>>>(            \
                input_ptrs, output_ptrs, sg_, self_sg_, output, rank_, size); \
        }                                                                     \
    } while(0)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2915
#else
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
// Non-DTK ROCm: use hipExtLaunchKernel with hipExtAddAcquireSystemScope
// for cross-device memory visibility.
#define KL(ngpus, name)                                                            \
    do                                                                             \
    {                                                                              \
        void* _args[] = {&input_ptrs, &output_ptrs, &sg_, &self_sg_,              \
                         &output, &rank_, &size};                                  \
        if(is_broadcast_reg_outptr)                                                \
        {                                                                          \
            hipExtLaunchKernel(reinterpret_cast<void*>(name<T, ngpus, true>),      \
                               dim3(blocks), dim3(threads), _args,                 \
                               0, stream, nullptr, nullptr,                        \
                               hipExtAddAcquireSystemScope);                       \
        }                                                                          \
        else                                                                       \
        {                                                                          \
            hipExtLaunchKernel(reinterpret_cast<void*>(name<T, ngpus, false>),     \
                               dim3(blocks), dim3(threads), _args,                 \
                               0, stream, nullptr, nullptr,                        \
                               hipExtAddAcquireSystemScope);                       \
        }                                                                          \
    } while(0)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2938
2939
#endif

2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
#define DISPATCH_REDUCE(ngpus, name)                      \
    do                                                    \
    {                                                     \
        if(bytes % (ngpus * 16) == 0 && world_size_ != 6) \
        {                                                 \
            if(use_write_mode)                            \
            {                                             \
                KL(ngpus, name##_write_mode);             \
            }                                             \
            else                                          \
            {                                             \
                KL(ngpus, name);                          \
            }                                             \
        }                                                 \
        else                                              \
        {                                                 \
            KL(ngpus, name##_naive);                      \
        }                                                 \
Xiaowei.zhang's avatar
Xiaowei.zhang committed
2958
2959
    } while(0)

2960
2961
2962
2963
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
#define REDUCE_CASE(ngpus)                               \
    case ngpus: {                                        \
        if(call_1stage)                                  \
        {                                                \
            KL(ngpus, cross_device_reduce_1stage);       \
        }                                                \
        else if(call_2stage)                             \
        {                                                \
            DISPATCH_REDUCE(ngpus, cross_device_reduce_2stage); \
        }                                                \
        break;                                           \
    }

        switch(world_size_)
        {
            REDUCE_CASE(2)
            REDUCE_CASE(4)
            REDUCE_CASE(6)
            REDUCE_CASE(8)
        default:
            throw std::runtime_error(
                "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
                "gpus = " +
                std::to_string(world_size_));
        }
    }
    else // use vllm allreduce kernel
    {
        int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define VLLM_REDUCE_CASE(ngpus)                              \
    case ngpus: {                                            \
        if(world_size_ == 2)                                 \
        {                                                    \
            KL(ngpus, cross_device_reduce_1stage);           \
        }                                                    \
        else if(full_nvlink_)                                \
        {                                                    \
            if((world_size_ <= 4 && bytes < 512 * 1024) ||   \
               (world_size_ <= 8 && bytes < 256 * 1024))     \
            {                                                \
                KL(ngpus, cross_device_reduce_1stage_naive); \
            }                                                \
            else                                             \
            {                                                \
                KL(ngpus, cross_device_reduce_2stage_naive); \
            }                                                \
        }                                                    \
        break;                                               \
    }

        switch(world_size_)
        {
            VLLM_REDUCE_CASE(2)
            VLLM_REDUCE_CASE(4)
            VLLM_REDUCE_CASE(6)
            VLLM_REDUCE_CASE(8)
        default:
            throw std::runtime_error(
                "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
                "gpus = " +
                std::to_string(world_size_));
        }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3022
3023
3024
    }
#undef REDUCE_CASE
#undef KL
3025
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3026

3027
3028
3029
/**
 * Launch reduce_scatter_first_dim for the current world size.
 *
 * @param stream HIP stream the kernel is enqueued on.
 * @param input  Local input buffer; its per-rank pointer table is looked up
 *               with get_buffer_RD().
 * @param output Destination buffer passed to the kernel.
 * @param size   Element count of `input`.
 *
 * The kernel receives `range = size / (world_size_ * d)` where d is the number
 * of elements in a 16-byte pack. The launch uses 512 threads per block and at
 * most 16 blocks.
 *
 * @throws std::runtime_error if world_size_ is not 2, 4 or 8. The previous
 *         behavior was to print a message and return as if the call had
 *         succeeded.
 *
 * NOTE(review): `size` is not checked for divisibility by world_size_ * d, so
 * a remainder is truncated away - confirm callers guarantee divisibility.
 */
template <typename T>
void dispatchReduceScatter(hipStream_t stream, T* input, T* output, int size)
{
    RankData* ptrs = get_buffer_RD(stream, input);

    auto d    = 16 / sizeof(T);
    int range = size / (world_size_ * d);

    dim3 block(512);
    int block_num = (range + 511) / 512;
    dim3 grid(std::min(16, block_num));

// One switch case: launch the kernel instantiated for NGPUS ranks.
#define AITER_REDUCE_SCATTER_CASE(NGPUS)                                             \
    case NGPUS:                                                                      \
        reduce_scatter_first_dim<T, NGPUS>                                           \
            <<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, range); \
        break;

    switch(world_size_)
    {
        AITER_REDUCE_SCATTER_CASE(8)
        AITER_REDUCE_SCATTER_CASE(4)
        AITER_REDUCE_SCATTER_CASE(2)
    default:
        throw std::runtime_error("reduce_scatter: unsupported world_size=" +
                                 std::to_string(world_size_));
    }

#undef AITER_REDUCE_SCATTER_CASE
}

/**
 * Launch one of the all-gather kernels for the current world size.
 *
 * @param stream        HIP stream the kernel is enqueued on.
 * @param input         Local input buffer; its per-rank pointer table is looked
 *                      up with get_buffer_RD().
 * @param output        Destination buffer passed to the kernel.
 * @param size          Element count of `input`.
 * @param last_dim_size Forwarded to allgather_lastdim; unused for gather_dim 0.
 * @param gather_dim    0 gathers along the first dimension; any other value is
 *                      treated as "gather along the last dimension".
 *
 * Kernel selection:
 *   - gather_dim == 0 and size not a multiple of the 16-byte pack width:
 *     allgather_naive, called with the element count.
 *   - gather_dim == 0 otherwise: allgather_vec, called with the pack count.
 *   - gather_dim != 0: allgather_lastdim, called with the pack count.
 * All launches use 512 threads per block and at most 80 blocks.
 *
 * @throws std::runtime_error if world_size_ is not 2, 4 or 8. The previous
 *         behavior was to print a message and return as if the call had
 *         succeeded.
 *
 * NOTE(review): the last-dim path divides `size` by the pack width without
 * checking the remainder - confirm callers guarantee divisibility.
 */
template <typename T>
void dispatchAllGather(
    hipStream_t stream, T* input, T* output, int size, int last_dim_size, int gather_dim)
{
    RankData* ptrs = get_buffer_RD(stream, input);
    auto d         = 16 / sizeof(T);
    dim3 block(512);

    // only support gather first dim and gather last dim
    const bool gather_first_dim = (gather_dim == 0);
    const bool use_naive        = gather_first_dim && (size % d != 0);

    int block_num = 0;
    if(use_naive)
    {
        // Element-wise copy: one thread per element.
        block_num = (size + 512 - 1) / 512;
    }
    else
    {
        // Vectorized copy: count in 16-byte packs, with the block's threads
        // split evenly across the ranks.
        size /= d;
        int tnum_per_block = 512 / world_size_;
        block_num          = (size + tnum_per_block - 1) / tnum_per_block;
    }
    dim3 grid(std::min(block_num, 80));

// Launch KERNEL<T, world_size_> with the given argument list.
#define AITER_ALLGATHER_LAUNCH(KERNEL, ...)                                    \
    switch(world_size_)                                                        \
    {                                                                          \
    case 8: KERNEL<T, 8><<<grid, block, 0, stream>>>(__VA_ARGS__); break;      \
    case 4: KERNEL<T, 4><<<grid, block, 0, stream>>>(__VA_ARGS__); break;      \
    case 2: KERNEL<T, 2><<<grid, block, 0, stream>>>(__VA_ARGS__); break;      \
    default:                                                                   \
        throw std::runtime_error("allgather: unsupported world_size=" +        \
                                 std::to_string(world_size_));                 \
    }

    if(use_naive)
    {
        AITER_ALLGATHER_LAUNCH(allgather_naive, ptrs, sg_, self_sg_, output, rank_, size)
    }
    else if(gather_first_dim)
    {
        AITER_ALLGATHER_LAUNCH(allgather_vec, ptrs, sg_, self_sg_, output, rank_, size)
    }
    else // gather last dim
    {
        AITER_ALLGATHER_LAUNCH(
            allgather_lastdim, ptrs, sg_, self_sg_, output, rank_, size, last_dim_size)
    }

#undef AITER_ALLGATHER_LAUNCH
}

template <typename T>
void dispatchFusedAllReduceRMSNorm(hipStream_t stream,
                                   T* input,
                                   T* residual_inp,
                                   T* residual_out,
                                   T* output,
                                   T* weight,
                                   float eps,
                                   int m,
                                   int n,
                                   bool use_1stage)
{
    auto d   = 16 / sizeof(T);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3148
    int size = m * n;
3149
    if(size % d != 0)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3150
    {
3151
3152
3153
        throw std::runtime_error("custom allreduce currently requires input length to be multiple "
                                 "of " +
                                 std::to_string(d));
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3154
3155
3156
3157
3158
3159
3160
3161
    }
    RankData* ptrs = get_buffer_RD(stream, input);
    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    hipGetDevice(&dev);
    hipGetDeviceProperties(&dev_prop, dev);
    uint32_t num_cu = dev_prop.multiProcessorCount;

3162
3163
3164
3165
3166
3167
3168
3169
3170
3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
    auto pack_size = 16 / sizeof(T);
    use_1stage     = use_1stage && (n % pack_size == 0) && (n / pack_size <= 1024);
#define MAYBE_DISPATCH_1S_KERNEL(NGPUS)                                            \
    if(use_1stage)                                                                 \
    {                                                                              \
        allreduce_fusion_kernel_1stage_launcher<T, T, NGPUS>(ptrs,                 \
                                                             sg_,                  \
                                                             self_sg_,             \
                                                             rank_,                \
                                                             residual_inp,         \
                                                             residual_out,         \
                                                             output,               \
                                                             weight,               \
                                                             nullptr,              \
                                                             size,                 \
                                                             n,                    \
                                                             eps,                  \
                                                             stream);              \
        return;                                                                    \
    }

Xiaowei.zhang's avatar
Xiaowei.zhang committed
3183
3184
3185
3186
    // step 1, run reduce-scatter + allgather cross device save
    dim3 block(512);
    int block_num = ((size / world_size_) + 512 - 1) / 512;
    dim3 grid(std::min(block_num, 80));
3187
    switch(world_size_)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3188
    {
3189
3190
    case 8:
        MAYBE_DISPATCH_1S_KERNEL(8);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3191
#ifdef DTK_ENV
3192
3193
        reduce_scatter_cross_device_store<T, 8>
            <<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3194
#else
3195
3196
3197
3198
3199
3200
3201
        {
            void* _rs_args[] = {&ptrs, &sg_, &self_sg_, &rank_, &size};
            hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 8>),
                               grid, block, _rs_args, 0, stream, nullptr, nullptr,
                               hipExtAddAcquireSystemScope);
        }
#endif
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3202
        break;
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
    case 4:
        MAYBE_DISPATCH_1S_KERNEL(4);
#ifdef DTK_ENV
        reduce_scatter_cross_device_store<T, 4>
            <<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
#else
        {
            void* _rs_args[] = {&ptrs, &sg_, &self_sg_, &rank_, &size};
            hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 4>),
                               grid, block, _rs_args, 0, stream, nullptr, nullptr,
                               hipExtAddAcquireSystemScope);
        }
#endif
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3216
        break;
3217
3218
3219
3220
3221
3222
3223
3224
3225
3226
3227
3228
    case 2:
        MAYBE_DISPATCH_1S_KERNEL(2);
#ifdef DTK_ENV
        reduce_scatter_cross_device_store<T, 2>
            <<<grid, block, 0, stream>>>(ptrs, sg_, self_sg_, rank_, size);
#else
        {
            void* _rs_args[] = {&ptrs, &sg_, &self_sg_, &rank_, &size};
            hipExtLaunchKernel(reinterpret_cast<void*>(reduce_scatter_cross_device_store<T, 2>),
                               grid, block, _rs_args, 0, stream, nullptr, nullptr,
                               hipExtAddAcquireSystemScope);
        }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3229
#endif
3230
3231
        break;
    default: throw std::runtime_error("fused allreduce rmsnorm: unsupported world_size=" + std::to_string(world_size_));
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3232
3233
    }

3234
3235
#undef MAYBE_DISPATCH_1S_KERNEL

Xiaowei.zhang's avatar
Xiaowei.zhang committed
3236
    // step 2, run allgather local device load + rmsnorm
3237
3238
3239
3240
3241
    int n_bytes  = n * sizeof(T);
    auto setGrid = [&](int naive_grid_size, const void* kernel_ptr) {
        int occupancy;
        hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_ptr, block.x, 0);
        grid.x = naive_grid_size < num_cu * occupancy ? naive_grid_size : num_cu * occupancy;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3242
3243
    };

3244
3245
3246
3247
3248
3249
3250
3251
#define launch_fused_allreduce_rmsnorm(template_kernel)                         \
    do                                                                          \
    {                                                                           \
        auto kernel_ptr = reinterpret_cast<const void*>(template_kernel);       \
        setGrid(naive_grid_size, kernel_ptr);                                   \
        template_kernel<<<grid, block, 0, stream>>>(                            \
            sg_, residual_inp, residual_out, output, weight, eps, rank_, m, n); \
    } while(0)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3252

3253
3254
3255
3256
3257
3258
    // n_packs = number of vectorized elements per row
    constexpr int ar_pack_size = 16 / sizeof(T);
    int n_packs                = n / ar_pack_size;
    // Choose tnum (block size, must be power of 2) and n_loop
    // local_device_load_rmsnorm handles bounds check for n_packs < tnum * n_loop
    if(n_packs >= 256)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3259
    {
3260
3261
        // tnum=512, n_loop = ceil(n_packs / 512)
        int n_loop          = (n_packs + 511) / 512;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3262
        int naive_grid_size = m;
3263
        if(n_packs == 512 * n_loop)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3264
        {
3265
3266
3267
            // exact fit -> use naive (no bounds check, slightly faster)
            switch(n_loop)
            {
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3268
            case 1:
3269
3270
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 1>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3271
            case 2:
3272
3273
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 2>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3274
            case 3:
3275
3276
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 3>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3277
            case 4:
3278
3279
3280
3281
3282
3283
3284
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 512, 4>));
                break;
            default:
                throw std::runtime_error(
                    "fused allreduce rmsnorm: n too large, m=" + std::to_string(m) +
                    " n=" + std::to_string(n) + " n_loop=" + std::to_string(n_loop));
            }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3285
3286
3287
        }
        else
        {
3288
3289
3290
3291
3292
3293
            // non-exact -> use bounds-checked version
            switch(n_loop)
            {
            case 1:
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 1>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3294
            case 2:
3295
3296
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 2>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3297
            case 3:
3298
3299
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 3>));
                break;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3300
            case 4:
3301
3302
3303
3304
3305
3306
3307
3308
3309
3310
3311
3312
3313
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 512, 4>));
                break;
            default:
                throw std::runtime_error(
                    "fused allreduce rmsnorm: n too large, m=" + std::to_string(m) +
                    " n=" + std::to_string(n) + " n_loop=" + std::to_string(n_loop));
            }
        }
    }
    else if(n_packs >= 64)
    {
        block.x             = 256;
        int n_loop          = (n_packs + 255) / 256;
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3314
        int naive_grid_size = m;
3315
        if(n_packs == 256 * n_loop)
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3316
        {
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
            switch(n_loop)
            {
            case 1:
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 256, 1>));
                break;
            case 2:
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm_naive<T, 256, 2>));
                break;
            default:
                throw std::runtime_error(
                    "fused allreduce rmsnorm: n too large for tnum=256, m=" + std::to_string(m) +
                    " n=" + std::to_string(n) + " n_loop=" + std::to_string(n_loop));
            }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3330
3331
3332
        }
        else
        {
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
            switch(n_loop)
            {
            case 1:
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 256, 1>));
                break;
            case 2:
                launch_fused_allreduce_rmsnorm((local_device_load_rmsnorm<T, 256, 2>));
                break;
            default:
                throw std::runtime_error(
                    "fused allreduce rmsnorm: n too large for tnum=256, m=" + std::to_string(m) +
                    " n=" + std::to_string(n) + " n_loop=" + std::to_string(n_loop));
            }
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3346
3347
3348
3349
        }
    }
    else
    {
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
3372
3373
3374
3375
3376
        throw std::runtime_error(
            "fused allreduce rmsnorm: n too small, m=" + std::to_string(m) +
            " n=" + std::to_string(n) + " n_packs=" + std::to_string(n_packs) +
            " (need n_packs >= 64, i.e. n >= " + std::to_string(64 * ar_pack_size) + ")");
    }
}

/**
 * Dispatch the fused all-reduce + RMSNorm + quantization kernel.
 *
 * Resolves the rank data for `input` through get_buffer_RD() and then selects
 * one of three launchers, instantiated for world_size_ in {2, 4, 8}:
 *   - 1-stage: `use_1stage` was requested and `n` satisfies the constraint;
 *   - 2-stage: `n` satisfies the constraint and m * n * sizeof(T) <= 512 KiB;
 *   - split:   `n` satisfies the constraint and the payload is larger.
 * The constraint on `n` is n % pack == 0 and n / pack <= 1024, where
 * pack = 16 / sizeof(T) is the number of T elements in one 16-byte pack.
 *
 * @param stream       HIP stream handed to get_buffer_RD() and to the launcher.
 * @param input        Buffer whose rank data is looked up; it is not passed to
 *                     the launcher directly.
 * @param residual_inp Forwarded unchanged to the launcher.
 * @param residual_out Forwarded unchanged to the launcher.
 * @param output       Forwarded unchanged to the launcher.
 * @param scale_out    Forwarded unchanged to the launcher.
 * @param weight       Forwarded unchanged to the launcher.
 * @param eps          Forwarded unchanged to the launcher.
 * @param m            Only used to form size = m * n.
 * @param n            Checked against the constraint above and forwarded.
 * @param use_1stage   Request for the 1-stage kernel; dropped when `n` does not
 *                     satisfy the constraint.
 * @throws std::runtime_error if m * n is not a multiple of 16 / sizeof(T), if
 *         `n` violates the constraint, or if world_size_ is not 2, 4 or 8.
 */
template <typename T, typename QT>
void dispatchFusedAllReduceRMSNormQuant(hipStream_t stream,
                                        T* input,
                                        T* residual_inp,
                                        T* residual_out,
                                        QT* output,
                                        float* scale_out,
                                        T* weight,
                                        float eps,
                                        int m,
                                        int n,
                                        bool use_1stage)
{
    // Number of T elements in one 16-byte pack. Kept unsigned (size_t) so the
    // modulo / division checks below see the same operand types as before.
    const auto pack_size = 16 / sizeof(T);
    int size             = m * n;
    if(size % pack_size != 0)
    {
        throw std::runtime_error("custom allreduce currently requires input length to be multiple "
                                 "of " +
                                 std::to_string(pack_size));
    }
    RankData* ptrs = get_buffer_RD(stream, input);

    bool n_constrain = (n % pack_size == 0) && (n / pack_size <= 1024);
    use_1stage       = use_1stage && n_constrain;

    // Every branch of the macro leaves the function (return or throw). The
    // final branch throws instead of printing: no launcher runs on that path,
    // so returning normally would hide the fact that nothing was computed.
#define DISPATCH_AR_FUSION_KERNEL(NGPUS)                                                       \
    if(use_1stage)                                                                             \
    {                                                                                          \
        allreduce_fusion_kernel_1stage_launcher<T, QT, NGPUS>(ptrs,                            \
                                                              sg_,                             \
                                                              self_sg_,                        \
                                                              rank_,                           \
                                                              residual_inp,                    \
                                                              residual_out,                    \
                                                              output,                          \
                                                              weight,                          \
                                                              scale_out,                       \
                                                              size,                            \
                                                              n,                               \
                                                              eps,                             \
                                                              stream);                         \
        return;                                                                                \
    }                                                                                          \
    else if(n_constrain && (size * sizeof(T) <= 512 * 1024))                                   \
    {                                                                                          \
        allreduce_fusion_kernel_2stage_launcher<T, QT, NGPUS>(ptrs,                            \
                                                              sg_,                             \
                                                              self_sg_,                        \
                                                              rank_,                           \
                                                              residual_inp,                    \
                                                              residual_out,                    \
                                                              output,                          \
                                                              weight,                          \
                                                              scale_out,                       \
                                                              size,                            \
                                                              n,                               \
                                                              eps,                             \
                                                              stream);                         \
        return;                                                                                \
    }                                                                                          \
    else if(n_constrain)                                                                       \
    {                                                                                          \
        allreduce_fusion_kernel_split_launcher<T, QT, NGPUS>(ptrs,                             \
                                                             sg_,                              \
                                                             self_sg_,                         \
                                                             rank_,                            \
                                                             residual_inp,                     \
                                                             residual_out,                     \
                                                             output,                           \
                                                             weight,                           \
                                                             scale_out,                        \
                                                             size,                             \
                                                             n,                                \
                                                             eps,                              \
                                                             stream);                          \
        return;                                                                                \
    }                                                                                          \
    else                                                                                       \
    {                                                                                          \
        throw std::runtime_error(                                                              \
            "fused allreduce rmsnorm quant: n=" + std::to_string(n) +                          \
            " not supported (must be multiple of " + std::to_string(pack_size) +               \
            " and n/" + std::to_string(pack_size) + " <= 1024)");                              \
    }

    switch(world_size_)
    {
    case 8: DISPATCH_AR_FUSION_KERNEL(8); break;
    case 4: DISPATCH_AR_FUSION_KERNEL(4); break;
    case 2: DISPATCH_AR_FUSION_KERNEL(2); break;
    default: throw std::runtime_error("fused allreduce rmsnorm: unsupported world_size=" + std::to_string(world_size_));
    }
    // Keep the helper macro local to this function; this is a header.
#undef DISPATCH_AR_FUSION_KERNEL
}

/**
 * Dispatch the fused all-reduce + RMSNorm + per-group quantization kernel.
 *
 * All arguments are validated before any launcher is selected. The fused
 * epilogue ``ar_fusion_epilogue_per_group`` reduces the per-group abs-max with
 * a butterfly ``__shfl_xor`` over packed 16-byte loads, so ``group_size`` must
 * satisfy, with PACK_SIZE = 16 / sizeof(T):
 *
 *   1. group_size > 0
 *   2. group_size is a multiple of PACK_SIZE
 *   3. group_size / PACK_SIZE is a power of two
 *   4. group_size / PACK_SIZE does not exceed the wavefront size (64)
 *   5. n is a multiple of group_size
 *
 * Breaking 1-4 would make the kernel produce wrong scales without any error
 * (bad butterfly stride, shuffles crossing a wavefront, or a partial pack in
 * a group); breaking 5 makes n / group_size non-integral. Each violation is
 * reported with its own std::runtime_error.
 *
 * Launcher choice, for world_size_ in {2, 4, 8}:
 *   - 1-stage when `use_1stage` is set and `n` fits (n % pack == 0 and
 *     n / pack <= 1024);
 *   - 2-stage when `n` fits and m * n * sizeof(T) <= 512 KiB;
 *   - split   when `n` fits and the payload is larger.
 *
 * `bf16_output` and every other pointer/scalar not mentioned above are
 * forwarded unchanged to the launcher; `input` is only used to look up the
 * rank data through get_buffer_RD().
 *
 * @throws std::runtime_error on any violated precondition, on an `n` that
 *         does not fit, or on an unsupported world_size_.
 */
template <typename T, typename QT>
void dispatchFusedAllReduceRMSNormQuantPerGroup(hipStream_t stream,
                                                T* input,
                                                T* residual_inp,
                                                T* residual_out,
                                                QT* output,
                                                float* scale_out,
                                                T* weight,
                                                float eps,
                                                int m,
                                                int n,
                                                int group_size,
                                                bool use_1stage,
                                                T* bf16_output = nullptr)
{
    constexpr int kPackSize      = 16 / sizeof(T);
    constexpr int kWavefrontSize = 64; // AMD CDNA wavefront width (gfx94x / gfx950)

    // Unsigned counterpart of kPackSize. The total-size and `n` checks are done
    // in unsigned arithmetic, so a negative value cannot slip through them.
    const auto pack_elems = 16 / sizeof(T);
    const int total_elems = m * n;

    // Single exit point for every rejected argument.
    const auto reject = [](const std::string& reason) { throw std::runtime_error(reason); };

    if(total_elems % pack_elems != 0)
    {
        reject("custom allreduce currently requires input length to be multiple of " +
               std::to_string(pack_elems));
    }

    // --- group_size preconditions 1 and 2 ---
    if(group_size <= 0)
    {
        reject("per-group quant requires group_size > 0, got group_size=" +
               std::to_string(group_size));
    }
    if(group_size % kPackSize != 0)
    {
        reject("per-group quant requires group_size divisible by PACK_SIZE=" +
               std::to_string(kPackSize) +
               " (16/sizeof(T)), got group_size=" + std::to_string(group_size));
    }

    // --- preconditions 3 and 4: shape of the butterfly reduction ---
    const int lanes_per_group = group_size / kPackSize;
    // Common tail of the two diagnostics below; only built when one fires.
    const auto describe_group = [&]() {
        return "got group_size=" + std::to_string(group_size) +
               " PACK_SIZE=" + std::to_string(kPackSize) +
               " threads_per_group=" + std::to_string(lanes_per_group);
    };
    const bool lanes_pow2 = (lanes_per_group & (lanes_per_group - 1)) == 0;
    if(!lanes_pow2)
    {
        reject("per-group quant requires group_size/PACK_SIZE to be a power of two "
               "(butterfly __shfl_xor reduction), " +
               describe_group());
    }
    if(lanes_per_group > kWavefrontSize)
    {
        reject("per-group quant requires group_size/PACK_SIZE <= wavefront size (" +
               std::to_string(kWavefrontSize) + "), " + describe_group());
    }

    // --- precondition 5 ---
    if(n % group_size != 0)
    {
        reject("per-group quant requires n divisible by group_size, n=" + std::to_string(n) +
               " group_size=" + std::to_string(group_size));
    }

    RankData* ptrs = get_buffer_RD(stream, input);

    // Kernel selection inputs, all independent of the world size.
    const bool n_ok        = (n % pack_elems == 0) && (n / pack_elems <= 1024);
    const bool take_1stage = use_1stage && n_ok;
    const bool fits_2stage = n_ok && (total_elems * sizeof(T) <= 512 * 1024);

    // Each expansion either launches exactly one kernel and returns, or throws.
#define DISPATCH_AR_FUSION_PG_KERNEL(NGPUS)                                                \
    do                                                                                     \
    {                                                                                      \
        if(!n_ok)                                                                          \
        {                                                                                  \
            throw std::runtime_error("per-group quant fused kernel: unsupported n");       \
        }                                                                                  \
        if(take_1stage)                                                                    \
        {                                                                                  \
            allreduce_fusion_kernel_1stage_per_group_launcher<T, QT, NGPUS>(               \
                ptrs, sg_, self_sg_, rank_,                                                \
                residual_inp, residual_out, output, weight, scale_out,                     \
                total_elems, n, group_size, eps, stream, bf16_output);                     \
        }                                                                                  \
        else if(fits_2stage)                                                               \
        {                                                                                  \
            allreduce_fusion_kernel_2stage_per_group_launcher<T, QT, NGPUS>(               \
                ptrs, sg_, self_sg_, rank_,                                                \
                residual_inp, residual_out, output, weight, scale_out,                     \
                total_elems, n, group_size, eps, stream, bf16_output);                     \
        }                                                                                  \
        else                                                                               \
        {                                                                                  \
            allreduce_fusion_kernel_split_per_group_launcher<T, QT, NGPUS>(                \
                ptrs, sg_, self_sg_, rank_,                                                \
                residual_inp, residual_out, output, weight, scale_out,                     \
                total_elems, n, group_size, eps, stream, bf16_output);                     \
        }                                                                                  \
        return;                                                                            \
    } while(0)

    switch(world_size_)
    {
    case 8: DISPATCH_AR_FUSION_PG_KERNEL(8); break;
    case 4: DISPATCH_AR_FUSION_PG_KERNEL(4); break;
    case 2: DISPATCH_AR_FUSION_PG_KERNEL(2); break;
    default:
        throw std::runtime_error(
            "fused allreduce rmsnorm per-group quant: unsupported world_size=" +
            std::to_string(world_size_));
    }
#undef DISPATCH_AR_FUSION_PG_KERNEL
}

/**
 * Dispatch the fused QK-norm + all-reduce 2-stage kernel.
 *
 * Resolves the rank data for `qkv_in` through get_buffer_RD() and calls
 * qknorm_allreduce_fusion_kernel_2stage_launcher, instantiated for
 * world_size_ in {2, 4, 8}.
 *
 * @param stream       HIP stream handed to get_buffer_RD() and to the launcher.
 * @param qkv_in       Buffer whose rank data is looked up; also forwarded to
 *                     the launcher.
 * @param q_w          Forwarded unchanged to the launcher.
 * @param k_w          Forwarded unchanged to the launcher.
 * @param q_out        Forwarded unchanged to the launcher.
 * @param k_out        Forwarded unchanged to the launcher.
 * @param v_out        Forwarded unchanged to the launcher.
 * @param token_num    Forwarded unchanged to the launcher.
 * @param hidden_dim_q Must be a multiple of 16 / sizeof(T); forwarded.
 * @param hidden_dim_k Must be a multiple of 16 / sizeof(T); forwarded.
 * @param hidden_dim_v Must be a multiple of 16 / sizeof(T); forwarded.
 * @param eps          Forwarded unchanged to the launcher.
 * @throws std::runtime_error if any hidden dimension is not a multiple of
 *         16 / sizeof(T), or if world_size_ is not 2, 4 or 8.
 */
template <typename T>
void dispatchFusedQKNormAllReduce(hipStream_t stream,
                                  T* qkv_in,
                                  T* q_w,
                                  T* k_w,
                                  T* q_out,
                                  T* k_out,
                                  T* v_out,
                                  int token_num,
                                  int hidden_dim_q,
                                  int hidden_dim_k,
                                  int hidden_dim_v,
                                  float eps)
{
    // Number of T elements in one 16-byte pack.
    auto d = 16 / sizeof(T);
    if(hidden_dim_q % d != 0 || hidden_dim_k % d != 0 || hidden_dim_v % d != 0)
    {
        throw std::runtime_error("custom allreduce currently requires input length to be multiple "
                                 "of " +
                                 std::to_string(d));
    }
    RankData* ptrs = get_buffer_RD(stream, qkv_in);

    // The expansion launches the kernel and returns from this function.
#define DISPATCH_QKNORM_AR_FUSION_KERNEL(NGPUS)                                \
    {                                                                          \
        qknorm_allreduce_fusion_kernel_2stage_launcher<T, NGPUS>(ptrs,         \
                                                                 sg_,          \
                                                                 self_sg_,     \
                                                                 rank_,        \
                                                                 qkv_in,       \
                                                                 q_w,          \
                                                                 k_w,          \
                                                                 q_out,        \
                                                                 k_out,        \
                                                                 v_out,        \
                                                                 token_num,    \
                                                                 hidden_dim_q, \
                                                                 hidden_dim_k, \
                                                                 hidden_dim_v, \
                                                                 eps,          \
                                                                 stream);      \
        return;                                                                \
    }

    switch(world_size_)
    {
    case 8: DISPATCH_QKNORM_AR_FUSION_KERNEL(8); break;
    case 4: DISPATCH_QKNORM_AR_FUSION_KERNEL(4); break;
    case 2: DISPATCH_QKNORM_AR_FUSION_KERNEL(2); break;
    default:
        throw std::runtime_error("fused qknorm allreduce rmsnorm: unsupported world_size=" +
                                 std::to_string(world_size_));
    }
    // Keep the helper macro local to this function; this is a header.
#undef DISPATCH_QKNORM_AR_FUSION_KERNEL
}

/**
 * Release the resources held by this instance.
 *
 * With DTK_ENV defined, frees `buffer_ptr_` (when non-null) with hipHostFree
 * and destroys `event_`. In every build, closes each pointer stored in
 * `ipc_handles_` with hipIpcCloseMemHandle.
 */
~CustomAllreduce()
{
#ifdef DTK_ENV
    if(buffer_ptr_)
        hipHostFree(buffer_ptr_);
    hipEventDestroy(event_);
#endif
    // Bind by const reference: only the mapped pointer is needed, so there is
    // no reason to copy each (key, pointer) entry while iterating.
    for(const auto& [_, ptr] : ipc_handles_)
    {
        HIP_CALL(hipIpcCloseMemHandle(ptr));
    }
}
Xiaowei.zhang's avatar
Xiaowei.zhang committed
3645
3646
3647
3648
3649
3650
3651
3652
}; // CustomAllreduce
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void aiter::CustomAllreduce::allreduce<half>(hipStream_t, half *,
 half *, int, int, int);
*/
} // namespace aiter