custom_all_reduce.cuh 51.1 KB
Newer Older
1
#pragma once
2
3
4
5
6
7
8
9
10
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <algorithm>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
11
12
13
14
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
15
16
#include <hip/hip_bf16.h>
// #if defined(USE_ROCM)
zhuwenwen's avatar
zhuwenwen committed
17
typedef __hip_bfloat16 nv_bfloat16;
18
// #endif
zhuwenwen's avatar
zhuwenwen committed
19

20
#include <iostream>
21
#include <array>
22
#include <limits>
Hanzhi Zhou's avatar
Hanzhi Zhou committed
23
#include <map>
24
25
#include <unordered_map>
#include <vector>
26
27
28
29
30
#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif
zhuwenwen's avatar
zhuwenwen committed
31
namespace vllm {
32
33
34
35
36
37
38
39
40
41
// Evaluates a CUDA/HIP runtime call exactly once; if it does not return
// cudaSuccess, prints the file, line and error string and terminates the
// process with EXIT_FAILURE.
// The temporary uses a name unlikely to collide with caller variables and
// `cmd` is parenthesized: with a local named plain `e`, a call such as
// CUDACHECK(f(e)) would expand to `cudaError_t e = f(e);` and read the
// uninitialized macro local instead of the caller's variable.
#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t cudacheck_err_ = (cmd);                             \
    if (cudacheck_err_ != cudaSuccess) {                            \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(cudacheck_err_));                   \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)
// Maximal number of blocks in allreduce kernel. It sizes the per-block arrays
// of Signal, so a kernel must not be launched with more blocks than this.
constexpr int kMaxBlocks = 128;

// Default number of blocks in allreduce kernel, plus the pointer attribute used
// to query an allocation's start address.
// rangeStartAddrAttr is constexpr (and therefore has internal linkage): a plain
// non-const, non-inline variable defined at namespace scope in a header breaks
// the one-definition rule as soon as the header is included from more than one
// translation unit.
#ifndef USE_ROCM
constexpr int defaultBlockLimit = 36;
constexpr CUpointer_attribute rangeStartAddrAttr =
    CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
constexpr int defaultBlockLimit = 16;
constexpr hipPointer_attribute rangeStartAddrAttr =
    HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif

// Barrier counter type. The counter may overflow, and unsigned overflow is
// well-defined. Wrap-around is only harmless for waits that compare counters
// for equality (!=); an ordered comparison (<) misfires right after the wrap.
using FlagType = uint32_t;
// Synchronization state of one rank, written to by its peers.
//
// A start and an end counter set are kept separately because a peer GPU block
// can reach the second sync point while the local block has not yet passed the
// first one. The peer may then already be writing counter+1 while the local
// block is still waiting for counter; alternating between the two arrays
// avoids that conflict.
//
// start/end are indexed [blockIdx.x][peer rank]; the second dimension of 16
// bounds the number of ranks. _flag[blockIdx.x] is the last counter value
// used by that block index on this rank.
struct Signal {
  alignas(128) FlagType start[kMaxBlocks][16];
  alignas(128) FlagType end[kMaxBlocks][16];
  alignas(128) FlagType _flag[kMaxBlocks];
};
// Base pointers of the buffer being all-reduced, one per rank and indexed by
// rank (at most 16 ranks).
struct __align__(16) RankData {
  const void* ptrs[16];
};
// Pointers to every rank's Signal, indexed by rank (at most 16 ranks).
struct __align__(16) RankSignals {
  Signal* signals[16];
};
// Fixed-size array like std::array, but aligned to alignof(T) * sz so that a
// whole array_t can be moved with a single wide load/store.
// `type` exposes the element type and `size` the element count.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};
// Vector types built around one 16-byte packet of T, so that loads and
// stores compile to ld.128 / st.128 instructions.
template <typename T>
struct packed_t {
  // (P)acked: 16 bytes worth of T.
  using P = array_t<T, 16 / sizeof(T)>;
  // (A)ccumulator: same lane count as P, in float, used for reductions.
  using A = array_t<float, 16 / sizeof(T)>;
  // (F): same lane count as P, in int8_t (used for the quantized outputs).
  using F = array_t<int8_t, 16 / sizeof(T)>;
};
// Marks a helper as an always-inlined device function.
#define DINLINE __device__ __forceinline__

// scalar cast functions
// half -> float widening through the __half2float intrinsic.
DINLINE float upcast_s(half val) {
  return __half2float(val);
}
// float -> T narrowing. Only explicit specializations are defined; this one
// handles half through the __float2half intrinsic.
template <typename T>
DINLINE T downcast_s(float val);

template <>
DINLINE half downcast_s(float val) {
  const half narrowed = __float2half(val);
  return narrowed;
}
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
// Adds `b` into `a` in place (via __hadd) and returns `a`.
DINLINE half& assign_add(half& a, half b) {
  const half total = __hadd(a, b);
  a = total;
  return a;
}
// Adds `b` into `a` in place and returns `a`.
DINLINE float& assign_add(float& a, float b) {
  a = a + b;
  return a;
}
// bfloat16 overloads. The architecture guard below is commented out, so these
// are always compiled.
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// bfloat16 -> float widening.
DINLINE float upcast_s(nv_bfloat16 val) {
  return __bfloat162float(val);
}

// float -> bfloat16 narrowing.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  const nv_bfloat16 narrowed = __float2bfloat16(val);
  return narrowed;
}
// Adds `b` into `a` in place (via __hadd) and returns `a`.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  const nv_bfloat16 total = __hadd(a, b);
  a = total;
  return a;
}
// #endif
// Lane-wise in-place accumulation a.data[k] += b.data[k] using the scalar
// assign_add overloads. Returns `a`.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int lane = 0; lane < N; ++lane) assign_add(a.data[lane], b.data[lane]);
  return a;
}
/**********************************************************/
// Lane-wise sum of two packed vectors. Each pair is added in float and the
// result is converted back to P's element type.
// VEC_SIZE is not used by the body (the loop runs over P::size lanes); it is
// kept because call sites pass it explicitly.
template <typename P, uint32_t VEC_SIZE>
DINLINE P vec_add(const P& a, const P& b) {
  using Elem = typename P::type;
  P out;
#pragma unroll
  for (int k = 0; k < P::size; ++k) {
    const float lhs = static_cast<float>(a.data[k]);
    const float rhs = static_cast<float>(b.data[k]);
    out.data[k] = static_cast<Elem>(lhs + rhs);
  }
  return out;
}
// Tree sum across `reducesize` lanes using shuffle-down. Afterwards lane 0
// holds the sum of lanes [0, reducesize); other lanes hold partial sums.
// NOTE(review): the default of 64 assumes 64-lane wavefronts (AMD); on a
// 32-lane warp the first step would not reduce correctly.
template <typename T, int reducesize=64>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int delta = reducesize >> 1; delta >= 1; delta /= 2) {
    const T other = WARP_SHFL_DOWN(val, delta);
    val = val + other;
  }
  return val;
}
// Block-wide sum reduction, assuming 64-lane wavefronts.
//  1. Each warp reduces its own values with WarpReduceSum.
//  2. Lane 0 of every warp stores its partial in shared[warp]; `shared` must
//     hold ceil(blockDim.x / 64) (<= 16) elements.
//  3. After __syncthreads(), all 64 lanes of warp 0 reduce the partials.
//     Lanes that have no partial contribute 0.
// Only thread 0 is guaranteed to receive the block total. For blockDim.x == 64
// the warp-level result is returned directly (again, complete only in lane 0).
// `shared` is not re-synchronized on exit, so a caller that reuses it must
// place a __syncthreads() between two calls.
//
// Why all of warp 0 runs step 3: a shuffle that reads from an inactive lane
// yields an undefined value, so the shuffle must not sit inside a lane-
// divergent branch. The warp count is rounded up so that a trailing partial
// warp is not dropped.
// NOTE(review): step 1 on a trailing partial warp still reads lanes beyond
// blockDim.x; confirm those reads are neutral on the target.
template <typename T>
DINLINE T BlockReduce(T val, T* shared) {
  const int lid = threadIdx.x % 64;
  const int wid = threadIdx.x / 64;
  const int block_size = blockDim.x;
  const int shared_size = (block_size + 63) / 64;  // number of (partial) warps
  val = WarpReduceSum<T>(val);
  if (block_size == 64) return val;
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = 0.f;
  if (wid == 0) {
    val = (lid < shared_size) ? shared[lid] : static_cast<T>(0.f);
    val = WarpReduceSum<T, 16>(val);
  }
  return val;
}
// RMS-normalizes one packed vector per thread across the whole block:
//   out[i] = residual[i] * rsqrt(sum(residual^2) / hidden_dim + eps) * gamma[i]
// where the sum of squares is reduced over every thread of the block.
// Despite the name, no addition happens here: callers pass the already summed
// vector as `residual`. The result is rounded to T.
// Contains __syncthreads(), so every thread of the block must call it.
// A is unused.
template <typename T, typename P, typename A>
DINLINE P fused_add_rms_norm(P const& residual, P const& gamma, int hidden_dim, float eps) {
  static constexpr int VEC_SIZE = 16 / sizeof(T);
  __shared__ float s_val;
  __shared__ float r_sum[16];

  float sq_sum = 0.0f;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    const float x = static_cast<float>(residual.data[i]);
    sq_sum += x * x;
  }
  sq_sum = BlockReduce(sq_sum, r_sum);
  // Only thread 0 holds the block total; broadcast 1/rms through shared memory.
  if (threadIdx.x == 0) s_val = rsqrtf(sq_sum / hidden_dim + eps);
  __syncthreads();

  const float inv_rms = s_val;
  P out;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    const float scaled =
        static_cast<float>(residual.data[i]) * inv_rms * static_cast<float>(gamma.data[i]);
    out.data[i] = static_cast<T>(scaled);
  }
  return out;
}
// Converts a float to int8 by rounding to nearest and saturating to
// [-128, 127].
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  // ROCm: nearbyint uses the current rounding mode (ties-to-even by default),
  // then the result is clamped to the int8 range.
  static constexpr auto kLo =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto kHi =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  float rounded = std::nearbyint(x);
  if (rounded < kLo) {
    rounded = kLo;
  } else if (rounded > kHi) {
    rounded = kHi;
  }
  return static_cast<int8_t>(rounded);
#else
  // CUDA: a single PTX convert with round-to-nearest-even and saturation.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
// Tree max across `reducesize` lanes using shuffle-down. Afterwards lane 0
// holds the maximum of lanes [0, reducesize).
// NOTE(review): the default of 64 assumes 64-lane wavefronts (AMD).
template <typename T, int reducesize=64>
__inline__ __device__ T WarpReduceMax(T val) {
#pragma unroll
  for (int delta = reducesize >> 1; delta >= 1; delta /= 2) {
    const T other = WARP_SHFL_DOWN(val, delta);
    val = fmaxf(val, other);
  }
  return val;
}
// Block-wide max reduction, assuming 64-lane wavefronts.
//  1. Each warp reduces with WarpReduceMax.
//  2. Lane 0 of each warp stores its partial in shared[warp]; `shared` must
//     hold ceil(blockDim.x / 64) (<= 16) elements.
//  3. After __syncthreads(), all 64 lanes of warp 0 reduce the partials.
// Only thread 0 is guaranteed to receive the block maximum. `shared` is not
// re-synchronized on exit.
//
// Why all of warp 0 runs step 3: a shuffle reading an inactive lane yields an
// undefined value, so the shuffle must not be inside a lane-divergent branch.
// Lanes with no partial re-read shared[0]; that is harmless because max is
// idempotent. The warp count is rounded up so a trailing partial warp counts.
template <typename T>
DINLINE T BlockReduceMax_ROW(T val, T* shared) {
  const int lid = threadIdx.x % 64;
  const int wid = threadIdx.x / 64;
  const int block_size = blockDim.x;
  const int shared_size = (block_size + 63) / 64;  // number of (partial) warps
  val = WarpReduceMax<T>(val);
  if (block_size == 64) return val;
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  if (wid == 0) {
    val = shared[lid < shared_size ? lid : 0];
    val = WarpReduceMax<T, 16>(val);
  }
  return val;
}
// Widens a packed array to float lanes; identity when T is already float.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int k = 0; k < N; ++k) widened.data[k] = static_cast<float>(val.data[k]);
    return widened;
  } else {
    return val;
  }
}
// Narrows float lanes to the packed type O; identity when O holds floats.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (!std::is_same<typename O::type, float>::value) {
    O narrowed;
#pragma unroll
    for (int k = 0; k < O::size; ++k)
      narrowed.data[k] = static_cast<typename O::type>(val.data[k]);
    return narrowed;
  } else {
    return val;
  }
}
#if 0
zhuwenwen's avatar
zhuwenwen committed
280

281
// Stores `flag` with release semantics at system scope. Targets below sm_70
// fall back to a system membar followed by a volatile store.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
#else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
#endif
}
// Loads a flag with acquire semantics at system scope. Targets below sm_70
// fall back to a volatile load followed by membar.gl.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType observed;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(observed)
               : "l"(flag_addr));
#else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
               : "=r"(observed)
               : "l"(flag_addr));
#endif
  return observed;
}
// Volatile (unordered) store of `flag` to global memory.
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
// Volatile (unordered) load of a flag from global memory.
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType observed;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(observed)
               : "l"(flag_addr));
  return observed;
}
// (Disabled PTX variant.) First synchronization in the all-reduce kernel, so
// it makes no visibility guarantees for prior memory accesses. Volatile
// writes are not reordered against other volatile writes.
// Thread i (< ngpus) writes the next counter into peer i's start slot for this
// rank and spins until peer i has written the same value into ours.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  const uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    FlagType* to_peer = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    FlagType* from_peer = &self_sg->start[blockIdx.x][threadIdx.x];
    st_flag_volatile(to_peer, flag);
    while (ld_flag_volatile(from_peer) != flag) {
    }
  }
  __syncthreads();
  // A single thread records the counter for the next barrier.
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// (Disabled PTX variant.) Second or final synchronization in the all-reduce
// kernel. When it is the final barrier, no visibility guarantees are needed
// for prior memory accesses, so volatile accesses are used. Otherwise the
// flag is published with release and observed with acquire semantics.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  const uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    FlagType* to_peer = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    FlagType* from_peer = &self_sg->end[blockIdx.x][threadIdx.x];
    if constexpr (final_sync) {
      st_flag_volatile(to_peer, flag);
      while (ld_flag_volatile(from_peer) != flag) {
      }
    } else {
      st_flag_release(to_peer, flag);
      while (ld_flag_acquire(from_peer) != flag) {
      }
    }
  }
  if constexpr (!final_sync) __syncthreads();
  // A single thread records the counter for the next barrier.
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#else

// First synchronization of an all-reduce kernel. Block blockIdx.x of this rank
// waits until the same block index of every peer rank has arrived. Relaxed
// atomics are used, so no ordering of prior memory accesses is provided.
//
// Thread i (< ngpus) writes this block's next counter value into peer i's
// start slot for this rank, then spins until peer i has written the same value
// into this rank's slot for peer i.
//
// The wait compares with `!=`, not `<`. FlagType is a wrapping uint32 counter:
// right after a wrap the stale value left in the slot (e.g. 0xFFFFFFFE) is
// numerically larger than the new flag (0), so a `<` test would fall through
// without waiting. Equality is sufficient because, as long as every kernel
// pairs this barrier with a later barrier_at_end (as the kernels in this file
// do), a peer cannot publish its next start value before this block has
// reached the end barrier of the current operation.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  // Counter value for this barrier: this block's previous value + 1.
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // Simultaneously write to the corresponding flag of all ranks
    // (latency = 1 p2p write).
    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                     __ATOMIC_RELAXED);
    // Wait for peer threadIdx.x's block with the same blockIdx.x to arrive.
    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                           __ATOMIC_RELAXED) != flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// Second (or final) synchronization barrier of an all-reduce kernel.
// Thread i (< ngpus) stores this block's next counter value into peer i's
// `end` slot for this rank, then spins until peer i has stored the same value
// into this rank's slot.
// - final_sync == false (used between the two stages of the two-stage
//   kernels): the flag is stored with release and loaded with acquire
//   ordering, and the block re-synchronizes before returning.
// - final_sync == true: relaxed atomics and no trailing __syncthreads().
//
// The wait compares with `!=`, not `<`: FlagType wraps, and right after a wrap
// the stale slot value is numerically larger than the new flag, so a `<` test
// would not wait at all. A peer cannot publish its next `end` value before this
// block has passed the next start barrier, so equality is sufficient.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // Simultaneously write to the corresponding flag of all ranks
    // (latency = 1 p2p write): tell peer threadIdx.x this block is done.
    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
    // Wait until peer threadIdx.x has finished the work of the same blockIdx.x.
    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) !=
           flag);
  }
  if constexpr (!final_sync) __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
// Variant of barrier_at_end that always ends with a block-wide __syncthreads()
// (also when final_sync is true) before thread 0 records the new counter.
// Memory ordering is chosen as in barrier_at_end: release/acquire unless
// final_sync, then relaxed.
// The wait compares with `!=` rather than `<` so that it keeps working after
// the wrapping uint32 counter overflows (with `<` the stale, larger value in
// the slot would let the barrier fall through without waiting).
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end_fuse(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // Simultaneously write to the corresponding flag of all ranks
    // (latency = 1 p2p write).
    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
    // Wait until we got the same flag from peer threadIdx.x.
    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) !=
           flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}
#endif

// Sums packed element `idx` over the buffers of all ngpus ranks. The sum is
// accumulated in the float accumulator type A, always in ptrs[0], ptrs[1], ...
// order, and converted back to P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A acc = upcast(ptrs[0][idx]);
#pragma unroll
  for (int src = 1; src < ngpus; ++src) {
    packed_assign_add(acc, upcast(ptrs[src][idx]));
  }
  return downcast<P>(acc);
}
// One-stage all-reduce. After the start barrier, every thread reads packed
// element idx from all ranks' buffers and writes the sum to `result`.
// `size` counts packed (16-byte) elements, and the loop advances by
// gridDim.x * blockDim.x. It ends with the relaxed final end barrier,
// presumably so that no rank leaves while a peer is still reading its input
// (NOTE(review): confirm against the host-side protocol).
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  RankData dp = *_dp;
  const P** inputs = (const P**)&dp.ptrs[0];
  P* out = reinterpret_cast<P*>(result);
  const int step = gridDim.x * blockDim.x;

  barrier_at_start<ngpus>(sg, self_sg, rank);
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += step) {
    out[idx] = packed_reduce<P, ngpus, A>(inputs, idx);
  }
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}
// Returns the address immediately past `*sg`, viewed as an array of P.
// Presumably the caller allocated scratch space directly after each Signal;
// NOTE(review): confirm the allocation layout on the host side.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  Signal* past_signal = sg + 1;
  return reinterpret_cast<P*>(past_signal);
}
// Two-stage all-reduce fused with a residual add and RMSNorm.
//
// Layout (from the indexing below): one token row per block iteration and one
// packed (16-byte) element per thread. The stage-2 loop bounds are uniform
// across the block only if threadIdx.x < hidden_dim / VEC_SIZE, so blockDim.x
// is expected to equal hidden_dim / VEC_SIZE (TODO confirm at the launch site;
// nothing is bounds-checked here).
// Tokens are partitioned across ranks: rank r owns tokens
// [begin_tokens[r], begin_tokens[r] + token_num_per_ranks[r]).
//
// Stage 1: each rank reduces the tokens it owns over all ranks' inputs and
//          writes the sums into every rank's scratch buffer (get_tmp_buf).
// Stage 2: after the end barrier, each rank reads every token from its own
//          scratch buffer, adds residual_in, RMS-normalizes with rms_gamma and
//          eps, and writes the rows to `result`. residual_in is not updated.
// `size` is unused.
template <typename T, int ngpus>
__global__ void __launch_bounds__(1024, 1)
    cross_device_reduce_2stage_fuse_norm(RankData* _dp, RankSignals sg, Signal* self_sg,
                                        T* __restrict__ result, int rank, int size,
                                        int hidden_dim, T* residual_in, T* rms_gamma,
                                        float eps, std::array<int, ngpus> begin_tokens,
                                        std::array<int, ngpus> token_num_per_ranks) {
  static constexpr int VEC_SIZE = 16 / sizeof(T);
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

  const int H_D_word_num = hidden_dim / VEC_SIZE;  // packed elements per token
  const int token_id = blockIdx.x;                 // first local token of this block
  const int access_id_in_token = threadIdx.x;      // packed element within a token
  const int token_stride = gridDim.x;
  const int access_id = token_id * H_D_word_num + access_id_in_token;
  const int access_stride = token_stride * H_D_word_num;

  // ptrs[i] / tmps[i] belong to rank (rank + i) % ngpus, so tmps[0] is this
  // rank's own scratch buffer.
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; ++i) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  P* tmp_out = tmps[0];

  // Packed-element range of the tokens owned by this rank.
  const int start = begin_tokens[rank] * H_D_word_num;
  const int part = (begin_tokens[rank] + token_num_per_ranks[rank]) * H_D_word_num;

  barrier_at_start<ngpus>(sg, self_sg, rank);

  // Stage 1: reduce the owned tokens and broadcast each sum to every rank's
  // scratch buffer (tmps[0] included). The sum stays in registers rather than
  // being re-read from tmp_out once per destination rank.
  for (int idx = access_id + start; idx < part; idx += access_stride) {
    const P reduced = packed_reduce<P, ngpus, A>(ptrs, idx);
#pragma unroll
    for (int r = 0; r < ngpus; ++r) tmps[r][idx] = reduced;
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // Stage 2: residual add + RMSNorm over every rank's token range.
  const P m_gamm_val = ((P*)rms_gamma)[access_id_in_token];
#pragma unroll
  for (int r = 0; r < ngpus; ++r) {
    const int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
    const int cm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
    // fused_add_rms_norm contains __syncthreads(); the loop condition is
    // uniform across the block given the launch layout described above.
    for (int idx = cm_access_id; idx < cm_tot_access; idx += access_stride) {
      P sum_val = tmp_out[idx];
      const P m_residual_val = ((P*)residual_in)[idx];
      sum_val = vec_add<P, VEC_SIZE>(sum_val, m_residual_val);
      sum_val = fused_add_rms_norm<T, P, A>(sum_val, m_gamm_val, hidden_dim, eps);
      ((P*)result)[idx] = sum_val;
    }
  }
}
// One-stage all-reduce fused with an optional residual add, RMSNorm and
// dynamic per-token int8 quantization.
//
// Layout (from the indexing below): block b handles token b only. There is no
// token loop, so gridDim.x is expected to equal the number of tokens. Thread t
// handles packed element t of that token, so blockDim.x is expected to equal
// hidden_dim / VEC_SIZE. TODO confirm both at the launch site; neither is
// bounds-checked here.
//
// Per token:
//   x = sum over ranks of the input row; if isResidual, x += residual_in and
//       residual_in is overwritten with x
//   y = x * rsqrt(sum(x^2) / hidden_dim + eps) * rms_gamma   (rounded to T)
//   scales[b] = max(max|y| / 127, min_scale)
//   result = int8 round-saturate(y * inv), where inv is the reciprocal of
//            scales[b] (computed as 127 / max|y| when the scale is not clamped)
//   norm_res = y if update_input
// result is written through the int8 packed type F regardless of T_out
// (NOTE(review): assumes T_out is int8_t). `size` is unused.
template <typename T, typename T_out, int ngpus, bool isResidual=true, bool update_input=false>
__global__ void __launch_bounds__(1024, 1)
    cross_device_reduce_1stage_norm_quant(RankData* _dp, RankSignals sg, Signal* self_sg,
                                          T_out* __restrict__ result, int rank, int size,
                                          int hidden_dim, T* residual_in, T* rms_gamma,
                                          float* __restrict__ scales, float eps,
                                          T* __restrict__ norm_res) {
  static constexpr int VEC_SIZE = packed_t<T>::P::size;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  using F = typename packed_t<T>::F;
  constexpr float qmax = 127.0f;
  // Smallest published scale (about FLT_EPSILON).
  constexpr float min_scale = 1.19209e-07f;

  const int H_D_word_num = hidden_dim / VEC_SIZE;
  const int token_id = blockIdx.x;
  const int access_id_in_token = threadIdx.x;
  const int access_id = token_id * H_D_word_num + access_id_in_token;

  const P m_gamm_val = reinterpret_cast<P*>(rms_gamma)[access_id_in_token];
  auto dp = *_dp;
  barrier_at_start<ngpus>(sg, self_sg, rank);
  P sum_val = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], access_id);
  barrier_at_end<ngpus, true>(sg, self_sg, rank);

  if constexpr (isResidual) {
    const P m_residual_val = reinterpret_cast<P*>(residual_in)[access_id];
    sum_val = vec_add<P, VEC_SIZE>(m_residual_val, sum_val);
    ((P*)residual_in)[access_id] = sum_val;
  }

  // RMSNorm. rsqrtf matches the two-stage kernels.
  __shared__ float s_val;
  __shared__ float r_sum[16];
  float acc = 0.f;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    float v = static_cast<float>(sum_val.data[i]);
    acc += v * v;
  }
  acc = BlockReduce<float>(acc, r_sum);
  if (threadIdx.x == 0)
    s_val = rsqrtf(acc / hidden_dim + eps);
  __syncthreads();

  P norm_out;
  float block_absmax_val_maybe = 0.f;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    norm_out.data[i] = static_cast<T>(static_cast<float>(sum_val.data[i]) * s_val *
                                      static_cast<float>(m_gamm_val.data[i]));
    block_absmax_val_maybe =
        fmaxf(block_absmax_val_maybe, fabsf(static_cast<float>(norm_out.data[i])));
  }

  // Only thread 0 holds the block maximum; broadcast it through shared memory.
  // (r_sum reuse is safe: the __syncthreads() above separates the two reductions.)
  block_absmax_val_maybe = BlockReduceMax_ROW(block_absmax_val_maybe, r_sum);
  __shared__ float s_token_scale;
  if (threadIdx.x == 0) s_token_scale = block_absmax_val_maybe;
  __syncthreads();

  // Quantize with exactly the scale that is published in `scales`. When the
  // published scale is clamped to min_scale, using 127 / absmax would make
  // q * scale disagree with y.
  const float raw_scale = s_token_scale / qmax;
  const float inv_s = (raw_scale < min_scale) ? 1.f / min_scale : qmax / s_token_scale;
  F out_vec;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i)
    out_vec.data[i] = float_to_int8_rn(static_cast<float>(norm_out.data[i]) * inv_s);
  ((F*)result)[access_id] = out_vec;
  if constexpr (update_input)
    ((P*)norm_res)[access_id] = norm_out;
  if (threadIdx.x == 0)
    scales[blockIdx.x] = fmaxf(raw_scale, min_scale);
}
// Two-stage all-reduce fused with an optional residual add, RMSNorm and
// dynamic per-token int8 quantization.
//
// Layout (from the indexing below): one token row per block iteration and one
// packed (16-byte) element per thread. blockDim.x is expected to equal
// hidden_dim / VEC_SIZE (TODO confirm at the launch site; not bounds-checked).
// Rank r owns tokens [begin_tokens[r], begin_tokens[r] + token_num_per_ranks[r]).
//
// Stage 1: each rank reduces its own tokens over all ranks' inputs and writes
//          the sums into every rank's scratch buffer (get_tmp_buf).
// Stage 2: each rank processes every token from its own scratch buffer:
//   x = sum (+ residual_in, which is then overwritten, if isResidual)
//   y = x * rsqrt(sum(x^2) / hidden_dim + eps) * rms_gamma   (rounded to T)
//   scales[token] = max(max|y| / 127, min_scale)
//   result = int8 round-saturate(y * inv), where inv is the reciprocal of the
//            published scale (127 / max|y| when not clamped)
//   norm_res = y if update_input
// result is written through the int8 packed type F (NOTE(review): assumes
// T_out is int8_t). `size` is unused.
template <typename T, typename T_out, int ngpus, bool isResidual=true, bool update_input=false>
__global__ void __launch_bounds__(1024, 1)
    cross_device_reduce_2stage_fuse_norm_quant(RankData* _dp, RankSignals sg, Signal* self_sg,
                                              T_out* __restrict__ result, int rank, int size,
                                              int hidden_dim, T* residual_in, T* rms_gamma,
                                              float* __restrict__ scales, float eps, 
                                              T* __restrict__ norm_res,
                                              std::array<int, ngpus> begin_tokens,
                                              std::array<int, ngpus> token_num_per_ranks) {
  static constexpr int VEC_SIZE = 16 / sizeof(T);
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  using F = typename packed_t<T>::F;
  constexpr float qmax = 127.0f;
  // Smallest published scale (about FLT_EPSILON).
  constexpr float min_scale = 1.19209e-07f;

  const int H_D_word_num = hidden_dim / VEC_SIZE;  // packed elements per token
  const int token_id = blockIdx.x;                 // first local token of this block
  const int access_id_in_token = threadIdx.x;      // packed element within a token
  const int token_stride = gridDim.x;
  const int access_id = token_id * H_D_word_num + access_id_in_token;
  const int access_stride = token_stride * H_D_word_num;

  // ptrs[i] / tmps[i] belong to rank (rank + i) % ngpus; tmps[0] is our own
  // scratch buffer.
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; ++i) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  P* tmp_out = tmps[0];

  const int start = begin_tokens[rank] * H_D_word_num;
  const int part = (begin_tokens[rank] + token_num_per_ranks[rank]) * H_D_word_num;

  barrier_at_start<ngpus>(sg, self_sg, rank);

  // Stage 1: reduce owned tokens and broadcast each sum (kept in registers) to
  // every rank's scratch buffer, tmps[0] included.
  for (int idx = access_id + start; idx < part; idx += access_stride) {
    const P reduced = packed_reduce<P, ngpus, A>(ptrs, idx);
#pragma unroll
    for (int r = 0; r < ngpus; ++r) tmps[r][idx] = reduced;
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  __shared__ float s_val;
  __shared__ float r_sum[16];
  __shared__ float s_token_scale;
  const P m_gamm_val = reinterpret_cast<P*>(rms_gamma)[access_id_in_token];
#pragma unroll
  for (int r = 0; r < ngpus; ++r) {
    const int cm_access_id = access_id + begin_tokens[r] * H_D_word_num;
    const int cm_token_id = token_id + begin_tokens[r];
    const int cm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * H_D_word_num;
    // The body contains __syncthreads(); the loop condition is uniform across
    // the block given the launch layout described above.
    for (int idx = cm_access_id, tidx = cm_token_id; idx < cm_tot_access;
         idx += access_stride, tidx += token_stride) {
      P sum_val = tmp_out[idx];
      if constexpr (isResidual) {
        const P m_residual_val = reinterpret_cast<P*>(residual_in)[idx];
        sum_val = vec_add<P, VEC_SIZE>(sum_val, m_residual_val);
        ((P*)residual_in)[idx] = sum_val;
      }

      float acc = 0.0f;
#pragma unroll
      for (int i = 0; i < VEC_SIZE; ++i) {
        float v = static_cast<float>(sum_val.data[i]);
        acc += v * v;
      }
      acc = BlockReduce(acc, r_sum);
      if (threadIdx.x == 0)
        s_val = rsqrtf(acc / hidden_dim + eps);
      __syncthreads();

      P norm_out;
      float block_absmax_val_maybe = 0.f;
#pragma unroll
      for (int i = 0; i < VEC_SIZE; ++i) {
        norm_out.data[i] = static_cast<T>(static_cast<float>(sum_val.data[i]) * s_val *
                                          static_cast<float>(m_gamm_val.data[i]));
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe, fabsf(static_cast<float>(norm_out.data[i])));
      }

      block_absmax_val_maybe = BlockReduceMax_ROW(block_absmax_val_maybe, r_sum);
      if (threadIdx.x == 0) s_token_scale = block_absmax_val_maybe;
      __syncthreads();

      // Quantize with exactly the scale published in `scales` (see header).
      const float raw_scale = s_token_scale / qmax;
      const float inv_s = (raw_scale < min_scale) ? 1.f / min_scale : qmax / s_token_scale;

      F out_vec;
#pragma unroll
      for (int i = 0; i < VEC_SIZE; ++i)
        out_vec.data[i] = float_to_int8_rn(static_cast<float>(norm_out.data[i]) * inv_s);

      ((F*)result)[idx] = out_vec;
      if constexpr (update_input)
        ((P*)norm_res)[idx] = norm_out;
      if (threadIdx.x == 0)
        scales[tidx] = fmaxf(raw_scale, min_scale);
    }
  }
}
// Two-stage all-reduce over `size` packed elements.
// The elements are split into ngpus slices of `part` elements; the last rank's
// slice also takes the remainder.
// Stage 1 (reduce-scatter): each rank sums its own slice over all ranks'
//          inputs into its scratch buffer.
// Stage 2 (all-gather): each rank copies every rank's reduced slice into
//          `result`.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  const int part = size / ngpus;
  const int start = rank * part;
  const int end = (rank == ngpus - 1) ? size : start + part;
  const int largest_part = part + size % ngpus;

  // ptrs[i] / tmps[i] belong to rank (rank + i) % ngpus.
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    const int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  P* const tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      const int src_rank = (rank + i) % ngpus;
      // Only the last rank's slice extends past `part` elements.
      const bool in_slice = (src_rank == ngpus - 1) || (idx < part);
      if (!in_slice) continue;
      ((P*)result)[src_rank * part + idx] = tmps[i][idx];
    }
  }
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
                                    T* __restrict__ result, int rank, int size,
                                    uint32_t** curr_hdp_reg, int world_size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

  // Local copy of the peer pointer table. The slots are intentionally NOT
  // rotated by rank, so every rank accumulates in the same order and all
  // ranks produce bitwise-identical results.
  RankData peers = *_dp;

  // Thread 1 of every block stores 0x1 through each curr_hdp_reg entry before
  // the start barrier (presumably the per-device HDP flush-control registers
  // prepared by CustomAllreduce — confirm at the call site).
  if (threadIdx.x == 1) {
    for (int dev = 0; dev < world_size; ++dev)
      __atomic_store_n(curr_hdp_reg[dev], 0x1, __ATOMIC_RELAXED);
  }

  barrier_at_start<ngpus>(sg, self_sg, rank);

  // Single pass: every packed element is reduced directly from all peers.
  const P** inputs = (const P**)&peers.ptrs[0];
  P* out = (P*)result;
  const int grid_threads = gridDim.x * blockDim.x;
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
       pos += grid_threads)
    out[pos] = packed_reduce<P, ngpus, A>(inputs, pos);

  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
                                    T* __restrict__ result, int rank, int size,
                                    uint32_t** curr_hdp_reg, int world_size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;

  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_threads = gridDim.x * blockDim.x;

  // Each rank owns the packed slice [rank * chunk, rank * chunk + chunk); the
  // last rank additionally owns the remainder size % ngpus.
  const int chunk = size / ngpus;
  const int my_begin = rank * chunk;
  const int my_end = (rank == ngpus - 1) ? size : my_begin + chunk;
  const int max_chunk = chunk + size % ngpus;

  const P* inputs[ngpus];
  P* scratch[ngpus];

  // Thread 1 of every block stores 0x1 through each curr_hdp_reg entry before
  // the start barrier (presumably the per-device HDP flush-control registers
  // prepared by CustomAllreduce — confirm at the call site).
  if (threadIdx.x == 1) {
    for (int dev = 0; dev < world_size; ++dev)
      __atomic_store_n(curr_hdp_reg[dev], 0x1, __ATOMIC_RELAXED);
  }

  // Peer inputs and peer scratch buffers, rotated so that slot 0 is this rank.
#pragma unroll
  for (int slot = 0; slot < ngpus; slot++) {
    const int peer = (rank + slot) % ngpus;
    inputs[slot] = (const P*)_dp->ptrs[peer];
    scratch[slot] = get_tmp_buf<P>(sg.signals[peer]);
  }
  P* my_scratch = scratch[0];

  barrier_at_start<ngpus>(sg, self_sg, rank);

  // Stage 1 (reduce-scatter): reduce this rank's slice across all peers and
  // park the result in this rank's scratch buffer.
  for (int pos = my_begin + thread_id; pos < my_end; pos += grid_threads)
    my_scratch[pos - my_begin] = packed_reduce<P, ngpus, A>(inputs, pos);

  barrier_at_end<ngpus>(sg, self_sg, rank);

  // Stage 2 (all-gather): copy every rank's reduced slice into `result`.
  // The thread-to-index mapping must match stage 1, because cross-device
  // visibility is only guaranteed between threads with the same id: the
  // thread that reduced my_begin + k in stage 1 gathers offset k here.
  P* out = (P*)result;
  for (int pos = thread_id; pos < max_chunk; pos += grid_threads) {
#pragma unroll
    for (int slot = 0; slot < ngpus; slot++) {
      const int src_rank = (rank + slot) % ngpus;
      // Offsets past `chunk` only exist in the last rank's (larger) slice.
      if (pos < chunk || src_rank == ngpus - 1)
        out[src_rank * chunk + pos] = scratch[slot][pos];
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
867
868
869
870
// Raw byte image of a cudaIpcMemHandle_t, usable as an ordered map key
// (std::array provides operator<).
typedef std::array<uint8_t, sizeof(cudaIpcMemHandle_t)> IPC_KEY;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t),
              "IPC_KEY must have the same size as cudaIpcMemHandle_t");
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t),
              "IPC_KEY must have the same alignment as cudaIpcMemHandle_t");

871
872
873
874
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  // Full NVLink or xGMI connection between GPUs.
  bool fully_connected_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each address used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  // Device-resident array of world_size_ register addresses, one per device,
  // obtained from hipDeviceGetAttribute(hipDeviceAttributeHdpMemFlushCntl).
  // It is passed to the *_pcie kernels, which store 0x1 through each entry.
  // Only allocated when constructed with fully_connected == false; it stays
  // nullptr otherwise, so the destructor never frees an indeterminate pointer.
  uint32_t** dev_curr_hdp_reg = nullptr;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second
   * section is for storing the intermediate results required by some
   * allreduce algos.
   *
   * Note: this class does not own any device memory except
   * dev_curr_hdp_reg (allocated only when fully_connected is false). Any
   * other required buffers are passed in from the constructor.
   *
   * Throws std::runtime_error if the HDP register query fails.
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool fully_connected = true)
      : rank_(rank),
        world_size_(world_size),
        fully_connected_(fully_connected),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
    if (!fully_connected) init_hdp_flush_regs();
  }

  // Opens (once) the IPC handle whose raw bytes start at ipc_handle and
  // returns the mapped pointer; repeated handles reuse the cached mapping.
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  // Returns, for every buffer recorded during graph capture, the IPC handle
  // of its allocation's base address (concatenated into one byte string) and
  // the buffer's byte offset from that base.
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  // Throws if `num` more RankData entries do not fit in the rank-data buffer.
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have internal reference
  // counting mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Allreduce variant for GPUs that are not fully connected. Same input
   * contract as allreduce(): `input` must be registered (or the stream must
   * be capturing a CUDA graph), and `size` counts elements of T and must be a
   * multiple of the packed vector width. The launched *_pcie kernels write
   * through dev_curr_hdp_reg, so this object must have been constructed with
   * fully_connected == false; otherwise std::runtime_error is thrown.
   */
  template <typename T>
  void allreduce_pcie(cudaStream_t stream, T* input, T* output, int size,
                      int threads = 512, int block_limit = defaultBlockLimit) {
    auto d = packed_t<T>::P::size;
    check_length_multiple(size, d);
    check_block_limit(block_limit);
    if (dev_curr_hdp_reg == nullptr)
      throw std::runtime_error(
          "allreduce_pcie requires HDP flush registers; construct "
          "CustomAllreduce with fully_connected = false");
    RankData* ptrs = get_rank_data(stream, input);

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                   \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_,     \
                                                 output, rank_, size,     \
                                                 dev_curr_hdp_reg,        \
                                                 world_size_);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage_pcie);     \
    } else {                                          \
      if ((world_size_ <= 4 && bytes < 128 * 8192) || \
          (world_size_ <= 8 && bytes < 8 * 8192)) {   \
        KL(ngpus, cross_device_reduce_1stage_pcie);   \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage_pcie);   \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      REDUCE_CASE(16)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  /**
   * Allreduce fused with a normalization step, launched through
   * cross_device_reduce_2stage_fuse_norm (defined elsewhere in this file).
   * NOTE(review): presumably an RMSNorm using residual / rms_weight / eps —
   * confirm against the kernel.
   *
   * Tokens are split across ranks as evenly as possible (the first
   * token_num % ngpus ranks take one extra token). Each block has
   * hidden_dim / d threads, where d is the packed vector width, so hidden_dim
   * must be a multiple of d. The grid is the largest per-rank token count
   * clamped to kMaxBlocks. `threads` is unused and `block_limit` is only
   * validated; both are kept for signature compatibility. As in allreduce(),
   * nothing is launched when world_size_ > 2 and the GPUs are not fully
   * connected.
   */
  template <typename T>
  void allreduce_fuse_norm(cudaStream_t stream, T* input, T* output, int size,
                           int token_num, int hidden_dim, T* residual,
                           T* rms_weight, double eps, int threads = 512,
                           int block_limit = defaultBlockLimit) {
    auto d = packed_t<T>::P::size;
    check_length_multiple(hidden_dim, d);
    check_block_limit(block_limit);
    RankData* ptrs = get_rank_data(stream, input);

#define KL(ngpus, name)                                                 \
  std::array<int, ngpus> begin_tokens, token_num_per_ranks;             \
  const int block_num = split_tokens_across_ranks<ngpus>(               \
      token_num, begin_tokens, token_num_per_ranks);                    \
  const int grid_size = std::min(kMaxBlocks, block_num);                \
  const int threads_in_block = hidden_dim / d;                          \
  name<T, ngpus><<<grid_size, threads_in_block, 0, stream>>>(           \
      ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,   \
      rms_weight, eps, begin_tokens, token_num_per_ranks);
#define REDUCE_CASE(ngpus)                               \
  case ngpus: {                                          \
    if (world_size_ == 2 || fully_connected_) {          \
      KL(ngpus, cross_device_reduce_2stage_fuse_norm);   \
    }                                                    \
    break;                                               \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  /**
   * Allreduce fused with normalization and quantization of the result into
   * `output`, with per-token values written to `scales`; the normalized
   * values go to `norm_out`. Kernels are defined elsewhere in this file:
   * - world_size_ == 2, or small payloads on fully connected GPUs:
   *   cross_device_reduce_1stage_norm_quant, one block per token;
   * - larger payloads on fully connected GPUs:
   *   cross_device_reduce_2stage_fuse_norm_quant, one block per token of the
   *   largest per-rank share (NOT clamped to kMaxBlocks, unlike
   *   allreduce_fuse_norm — NOTE(review): confirm the kernel relies on that).
   * Each block has hidden_dim / d threads, so hidden_dim must be a multiple of
   * the packed vector width d. `threads` is unused and `block_limit` is only
   * validated. Nothing is launched when world_size_ > 2 and the GPUs are not
   * fully connected.
   */
  template <typename scalar_in_t, typename scalar_out_t, bool isResidual = true,
            bool update_input = false>
  void allreduce_fuse_norm_quant(cudaStream_t stream, scalar_in_t* input,
                                 scalar_out_t* output, int size, int token_num,
                                 int hidden_dim, scalar_in_t* residual,
                                 scalar_in_t* rms_weight, scalar_in_t* norm_out,
                                 double eps, float* scales, int threads = 512,
                                 int block_limit = defaultBlockLimit) {
    auto d = packed_t<scalar_in_t>::P::size;
    check_length_multiple(hidden_dim, d);
    check_block_limit(block_limit);
    RankData* ptrs = get_rank_data(stream, input);

    int thread_per_token = hidden_dim / d;
    auto bytes = (size / d) * sizeof(typename packed_t<scalar_in_t>::P);
#define KL1(ngpus, name)                                                   \
  std::array<int, ngpus> begin_tokens, token_num_per_ranks;                \
  const int rank_blocks = split_tokens_across_ranks<ngpus>(                \
      token_num, begin_tokens, token_num_per_ranks);                       \
  name<scalar_in_t, scalar_out_t, ngpus, isResidual, update_input>         \
      <<<rank_blocks, thread_per_token, 0, stream>>>(                      \
          ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,  \
          rms_weight, scales, eps, norm_out, begin_tokens,                 \
          token_num_per_ranks);
#define KL(ngpus, name)                                                    \
  name<scalar_in_t, scalar_out_t, ngpus, isResidual, update_input>         \
      <<<token_num, thread_per_token, 0, stream>>>(                        \
          ptrs, sg_, self_sg_, output, rank_, size, hidden_dim, residual,  \
          rms_weight, scales, eps, norm_out);
#define REDUCE_CASE(ngpus)                                      \
  case ngpus: {                                                 \
    if (world_size_ == 2) {                                     \
      KL(ngpus, cross_device_reduce_1stage_norm_quant);         \
    } else if (fully_connected_) {                              \
      if ((world_size_ <= 4 && bytes < 1024 * 1024) ||          \
          (world_size_ <= 8 && bytes < 512 * 1024)) {           \
        KL(ngpus, cross_device_reduce_1stage_norm_quant);       \
      } else {                                                  \
        KL1(ngpus, cross_device_reduce_2stage_fuse_norm_quant); \
      }                                                         \
    }                                                           \
    break;                                                      \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL1
#undef KL
  }

  /**
   * Performs allreduce, assuming input has already been registered (or that
   * the stream is capturing a CUDA graph, in which case the input is recorded
   * and its peer pointers are filled in later by register_graph_buffers).
   *
   * Block and grid default configs are results after careful grid search.
   * Using 36 blocks give the best or close to the best runtime on the devices
   * I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
   * only take a small amount of SMs. Not quite sure the underlying reason,
   * but my guess is that too many SMs will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    auto d = packed_t<T>::P::size;
    check_length_multiple(size, d);
    check_block_limit(block_limit);
    RankData* ptrs = get_rank_data(stream, input);

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (fully_connected_) {                    \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
    // Only allocated by the constructor when fully_connected was false.
    if (dev_curr_hdp_reg != nullptr) cudaFree(dev_curr_hdp_reg);
  }

 private:
  // Queries each device's HDP flush-control register address on the host and
  // uploads the table to freshly allocated device memory (dev_curr_hdp_reg).
  // The addresses are gathered in host memory first because the host must
  // not write directly into cudaMalloc'd memory.
  void init_hdp_flush_regs() {
    std::vector<uint32_t*> hdp_regs(world_size_, nullptr);
    for (int i = 0; i < world_size_; ++i) {
      // The result is stored into a pointer-sized slot, i.e. HIP is expected
      // to write a full register address through the int* out-parameter for
      // this attribute. NOTE(review): confirm against the HIP version in use.
      if (hipDeviceGetAttribute(reinterpret_cast<int*>(&hdp_regs[i]),
                                hipDeviceAttributeHdpMemFlushCntl,
                                i) != hipSuccess)
        throw std::runtime_error(
            "failed to query HDP flush register of device " +
            std::to_string(i));
    }
    CUDACHECK(cudaMalloc((void**)&dev_curr_hdp_reg,
                         world_size_ * sizeof(uint32_t*)));
    CUDACHECK(cudaMemcpy(dev_curr_hdp_reg, hdp_regs.data(),
                         world_size_ * sizeof(uint32_t*),
                         cudaMemcpyHostToDevice));
  }

  // Throws unless `len` is a multiple of the packed vector width `d`.
  static void check_length_multiple(int len, int d) {
    if (len % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
  }

  // Throws when more than kMaxBlocks blocks are requested.
  static void check_block_limit(int block_limit) {
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));
  }

  // Returns the device RankData to pass to a kernel for `input`. While the
  // stream is capturing a CUDA graph, reserves the next rank-data slot and
  // records `input` in graph_unreg_buffers_ (filled in later by
  // register_graph_buffers). Otherwise looks up a registered buffer and
  // throws if `input` was never registered.
  RankData* get_rank_data(cudaStream_t stream, void* input) {
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      RankData* ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
      return ptrs;
    }
    auto it = buffers_.find(input);
    if (it == buffers_.end())
      throw std::runtime_error(
          "buffer address " +
          std::to_string(reinterpret_cast<uint64_t>(input)) +
          " is not registered!");
    return it->second;
  }

  // Splits token_num tokens across ngpus ranks as evenly as possible: rank i
  // starts at begin_tokens[i] and owns token_counts[i] tokens, with the first
  // token_num % ngpus ranks taking one extra token. Returns the largest
  // per-rank token count.
  template <int ngpus>
  static int split_tokens_across_ranks(int token_num,
                                       std::array<int, ngpus>& begin_tokens,
                                       std::array<int, ngpus>& token_counts) {
    const int remaining = token_num % ngpus;
    const int per_rank = token_num / ngpus;
    for (int i = 0; i < ngpus; ++i) {
      begin_tokens[i] = i * per_rank + (remaining > i ? i : remaining);
      token_counts[i] = per_rank + (remaining > i ? 1 : 0);
    }
    return remaining ? per_rank + 1 : per_rank;
  }
};
1347

1348
/**
 * To inspect PTX/SASS, copy and paste this header file into Compiler Explorer
 * and add a template instantiation, for example:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 * half *, int, int, int);
 */
1354
}  // namespace vllm