custom_all_reduce.cuh 27.6 KB
Newer Older
1
2
3
4
5
6
7
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

zhuwenwen's avatar
zhuwenwen committed
8
9
10
11
#if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16;
#endif

12
#include <iostream>
13
#include <array>
14
#include <limits>
Hanzhi Zhou's avatar
Hanzhi Zhou committed
15
#include <map>
16
17
18
#include <unordered_map>
#include <vector>

zhuwenwen's avatar
zhuwenwen committed
19
namespace vllm {
20
21
22
23
24
25
26
27
28
29
// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the source location and the runtime's error string,
// then terminates the process with EXIT_FAILURE.
#define CUDACHECK(cmd)                                              \
  do {                                                              \
    const cudaError_t cuda_status_ = (cmd);                         \
    if (cuda_status_ != cudaSuccess) {                              \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(cuda_status_));                     \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

zhuwenwen's avatar
zhuwenwen committed
30
// Maximal number of blocks in allreduce kernel.
31
// Also sizes the per-block counter arrays in Signal; allreduce() and
// allreduce_pcie() reject any block_limit above this value.
constexpr int kMaxBlocks = 36;
zhuwenwen's avatar
zhuwenwen committed
32
33
34
35

// Default number of blocks in allreduce kernel, and the driver attribute used
// to look up the base address of an allocation. Both are platform specific.
//
// Both are declared const so they have internal linkage: this is a header, and
// a mutable namespace-scope variable defined here would be defined once per
// translation unit that includes it.
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
const CUpointer_attribute rangeStartAddrAttr =
    CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
const hipPointer_attribute rangeStartAddrAttr =
    HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif

43
44
45
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
// Element type of every synchronization counter stored in Signal.
using FlagType = uint32_t;
zhuwenwen's avatar
zhuwenwen committed
46

47
48
49
50
51
52
// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for a peer GPU block to arrive
// at the second sync point while the current GPU block hasn't passed the first
// sync point. Thus, the peer GPU may write counter+1 while the current GPU is
// busy waiting for counter. We use alternating counter arrays to avoid this
// possibility.
53
// Synchronization state of one rank. The barrier helpers index the first
// dimension with blockIdx.x and the second with a rank id.
struct Signal {
  // start[b][r]: counter written by rank r for block b in barrier_at_start.
  alignas(128) FlagType start[kMaxBlocks][16];
  // end[b][r]: counter written by rank r for block b in barrier_at_end.
  alignas(128) FlagType end[kMaxBlocks][16];
  // Value reached by the most recent barrier of each block on this rank; the
  // next barrier waits for this value plus one.
  alignas(128) FlagType _flag[kMaxBlocks];  // incremental flags for each rank
};

59
// One buffer pointer per rank (slot i belongs to rank i); at most 16 ranks.
struct __align__(16) RankData {
  const void* ptrs[16];
};
62

63
// One Signal pointer per rank (slot i belongs to rank i); at most 16 ranks.
// Passed to the kernels by value.
struct __align__(16) RankSignals {
  Signal* signals[16];
};
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

// like std::array, but aligned
// The alignment is alignof(T) * sz, so the whole array is aligned as one unit.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  // Element type, exposed so generic code can recover T from the array type.
  using type = T;
  // Number of elements, exposed as a compile-time constant.
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
// Both member types hold 16 / sizeof(T) elements, i.e. P spans 16 bytes.
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

// Qualifier shorthand for the small device-only helpers below.
#define DINLINE __device__ __forceinline__

// scalar cast functions
// Widens one half value to float via the __half2float intrinsic.
DINLINE float upcast_s(half val) { return __half2float(val); }

// Primary template: narrows a float to T. Only declared here; each supported
// T supplies an explicit specialization.
template <typename T>
DINLINE T downcast_s(float val);
// half specialization, implemented with the __float2half intrinsic.
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
100
// In-place add for half: a = a + b using the __hadd intrinsic (the + operator
// is not relied upon, see the note above). Returns a reference to a.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
104
// In-place add for float; returns a reference to a.
DINLINE float& assign_add(float& a, float b) { return a += b; }
105

zhuwenwen's avatar
zhuwenwen committed
106
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
107
108
109
110
111
// bfloat16 counterparts of the scalar casts: widen with __bfloat162float and
// narrow with __float2bfloat16.
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
112
// In-place add for bfloat16: a = a + b using the __hadd intrinsic. Returns a
// reference to a.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
zhuwenwen's avatar
zhuwenwen committed
116
// #endif
117
118

// Element-wise in-place add of two packed values: a.data[i] += b.data[i] for
// every i in [0, N), using the scalar assign_add overload for T. Returns a
// reference to a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

// Converts a packed value to its float accumulator form. A float input is
// returned as-is; any other element type is widened one element at a time
// with upcast_s.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane != N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}

// Converts a float accumulator back to the packed output type O. When O holds
// floats the input is returned as-is; otherwise each element is narrowed with
// downcast_s<typename O::type>.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane != O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}

zhuwenwen's avatar
zhuwenwen committed
155
156
#if !defined(USE_ROCM)

157
// Stores `flag` to `flag_addr` with release semantics at system scope.
// On SM70+ this is a single st.release.sys; on older architectures it is a
// membar.sys followed by a volatile store.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
  #else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
  #endif
}

// Loads the flag at `flag_addr` with acquire semantics at system scope and
// returns it.
// On SM70+ this is a single ld.acquire.sys. On older architectures it is a
// volatile load followed by a fence; the fence is membar.sys (not membar.gl)
// so that it matches the system scope used by st_flag_release, whose store
// may come from a peer device.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  #else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.sys;"
               : "=r"(flag)
               : "l"(flag_addr));
  #endif
  return flag;
}

// Stores `flag` to `flag_addr` with a volatile global store and no fence.
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

// Loads and returns the flag at `flag_addr` with a volatile global load and
// no fence.
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

zhuwenwen's avatar
zhuwenwen committed
193
194
195
196
197
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
// sg:      Signal pointers of all ranks.
// self_sg: Signal of the calling rank.
// rank:    id of the calling rank.
//
// Preconditions visible from the indexing below: the block must have at least
// `ngpus` threads (thread t handles peer t) and blockIdx.x must be a valid
// index into the per-block counter arrays of Signal. Must be called by every
// thread of the block because it contains a __syncthreads().
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value
    // from peer.
    st_flag_volatile(peer_counter_ptr, flag);
    while (ld_flag_volatile(self_counter_ptr) != flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
//
// final_sync == false: counters are exchanged with release/acquire accesses
//                      and the block is re-synchronized before returning.
// final_sync == true:  counters are exchanged with plain volatile accesses
//                      and there is no trailing __syncthreads().
//
// Same preconditions as barrier_at_start: at least `ngpus` threads per block,
// blockIdx.x within the per-block counter arrays, and every thread of the
// block must call it.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  // Make sure every thread of the block has finished its preceding work
  // before any counter is published.
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    if constexpr (!final_sync) {
      st_flag_release(peer_counter_ptr, flag);
      while (ld_flag_acquire(self_counter_ptr) != flag);
    } else {
      st_flag_volatile(peer_counter_ptr, flag);
      while (ld_flag_volatile(self_counter_ptr) != flag);
    }
  }
  if constexpr (!final_sync) __syncthreads();

  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

#else

// ROCm variant of the first barrier of the all reduce kernel.
//
// Thread t (t < ngpus) publishes this rank's next counter value into peer t's
// `start` slot and then spins until peer t has published at least that value
// into this rank's slot. Relaxed atomics are used; no ordering of other
// memory accesses is requested here.
//
// Preconditions visible from the indexing below: at least `ngpus` threads per
// block, blockIdx.x within the per-block counter arrays of Signal, and every
// thread of the block must call it (it contains a __syncthreads()).
//
// Note: the scoped builtins (__scoped_atomic_store_n with system scope and
// __scoped_atomic_load_n with device scope) were previously commented out in
// favour of the unscoped __atomic_* builtins used here.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                     __ATOMIC_RELAXED);
    // Wait until the peer's counter has reached `flag`. The comparison is done
    // on the signed difference so that "observed is still behind flag" stays
    // correct when the 32-bit counter wraps around; a plain `observed < flag`
    // would stop waiting as soon as flag wraps to 0.
    while (static_cast<int32_t>(
               __atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                               __ATOMIC_RELAXED) -
               flag) < 0);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// ROCm variant of the second / final barrier of the all reduce kernel.
//
// final_sync == false: the counter is stored with release ordering and loaded
//                      with acquire ordering, and the block is re-synchronized
//                      before returning.
// final_sync == true:  relaxed ordering is used and there is no trailing
//                      __syncthreads().
//
// Same preconditions as the start barrier: at least `ngpus` threads per block,
// blockIdx.x within the per-block counter arrays, every thread of the block
// must call it.
//
// Note: the scoped builtins (__scoped_atomic_store_n with system scope and
// __scoped_atomic_load_n with device scope) were previously commented out in
// favour of the unscoped __atomic_* builtins used here.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  // Every thread of the block must have finished its preceding work before
  // any counter is published.
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
    // Wait until the peer's counter has reached `flag`. The signed difference
    // keeps the "still behind" test correct across 32-bit counter wraparound;
    // a plain `observed < flag` would stop waiting when flag wraps to 0.
    while (static_cast<int32_t>(
               __atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                               final_sync ? __ATOMIC_RELAXED
                                          : __ATOMIC_ACQUIRE) -
               flag) < 0);
  }
  if constexpr (!final_sync) __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

zhuwenwen's avatar
zhuwenwen committed
293
294
#endif

295
// Sums packed element `idx` across `ngpus` input buffers.
//
// P:    packed load/store type, A: float accumulator type of the same width.
// ptrs: one buffer pointer per rank; ptrs[0..ngpus-1] are all read.
// idx:  index in units of P.
//
// Each input is widened to A, accumulated in the order ptrs[0], ptrs[1], ...,
// and the total is narrowed back to P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

// One-stage all reduce: every rank reads element idx from all ranks' inputs
// and writes the sum to its own `result`.
//
// _dp:     device pointer to the RankData holding every rank's input pointer.
// sg:      Signal pointers of all ranks (by value).
// self_sg: Signal of this rank.
// result:  output buffer, accessed as an array of packed_t<T>::P.
// rank:    id of this rank.
// size:    number of packed_t<T>::P elements to reduce (the host wrapper
//          divides the scalar element count by the pack width before launch).
//
// Launch: 1-D grid and block, at most 512 threads per block
// (__launch_bounds__), at least `ngpus` threads per block and blockIdx.x
// within the per-block counter arrays of Signal, as required by the barriers.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  barrier_at_start<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  // Final barrier of the kernel, hence final_sync = true.
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}

// Returns the scratch area that starts immediately after the Signal object
// `sg` points to, viewed as an array of P.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}

// Two-stage all reduce (reduce-scatter followed by all-gather).
//
// Stage 1: this rank sums its own slice [start, end) of every rank's input
//          into the scratch buffer that follows its own Signal.
// Stage 2: this rank copies every rank's reduced slice out of that rank's
//          scratch buffer into `result`.
//
// The input is split into `ngpus` slices of size / ngpus packed elements; the
// last rank additionally owns the remainder.
//
// _dp:     device pointer to the RankData holding every rank's input pointer.
// sg:      Signal pointers of all ranks (by value); the scratch buffers are
//          derived from them with get_tmp_buf.
// self_sg: Signal of this rank.
// result:  output buffer, accessed as an array of packed_t<T>::P.
// rank:    id of this rank.
// size:    number of packed_t<T>::P elements to reduce.
//
// Launch: 1-D grid and block, at most 512 threads per block
// (__launch_bounds__), at least `ngpus` threads per block and blockIdx.x
// within the per-block counter arrays of Signal, as required by the barriers.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Rotate the rank order so that slot 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.

  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's slice extends past `part`.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

zhuwenwen's avatar
zhuwenwen committed
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// One-stage all reduce for the non-fully-connected path. Identical to
// cross_device_reduce_1stage except that, before the start barrier, thread 1
// of every block stores 0x1 to each of the `world_size` addresses in
// `curr_hdp_reg`.
// NOTE(review): the host fills curr_hdp_reg from
// hipDeviceAttributeHdpMemFlushCntl, so these stores presumably trigger an
// HDP flush on each device — confirm against the HIP documentation.
//
// _dp:          device pointer to the RankData with every rank's input pointer.
// sg:           Signal pointers of all ranks (by value).
// self_sg:      Signal of this rank.
// result:       output buffer, accessed as an array of packed_t<T>::P.
// rank:         id of this rank.
// size:         number of packed_t<T>::P elements to reduce.
// curr_hdp_reg: device array of `world_size` register addresses.
// world_size:   number of entries in curr_hdp_reg.
//
// Launch: 1-D grid and block, at most 512 threads per block
// (__launch_bounds__), at least `ngpus` threads per block and blockIdx.x
// within the per-block counter arrays of Signal. The register stores are
// issued by threadIdx.x == 1, so a block needs at least two threads for them
// to happen.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size,
                               uint32_t** curr_hdp_reg, int world_size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  if (threadIdx.x == 1) {
    for (int i = 0; i < world_size; i++) {
      __atomic_store_n(curr_hdp_reg[i], 0x1, __ATOMIC_RELAXED);
    }
  }
  barrier_at_start<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  // Final barrier of the kernel, hence final_sync = true.
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}

// Two-stage all reduce (reduce-scatter followed by all-gather) for the
// non-fully-connected path. Identical to cross_device_reduce_2stage except
// that, before the start barrier, thread 1 of every block stores 0x1 to each
// of the `world_size` addresses in `curr_hdp_reg`.
// NOTE(review): the host fills curr_hdp_reg from
// hipDeviceAttributeHdpMemFlushCntl, so these stores presumably trigger an
// HDP flush on each device — confirm against the HIP documentation.
//
// _dp:          device pointer to the RankData with every rank's input pointer.
// sg:           Signal pointers of all ranks (by value); scratch buffers are
//               derived from them with get_tmp_buf.
// self_sg:      Signal of this rank.
// result:       output buffer, accessed as an array of packed_t<T>::P.
// rank:         id of this rank.
// size:         number of packed_t<T>::P elements to reduce.
// curr_hdp_reg: device array of `world_size` register addresses.
// world_size:   number of entries in curr_hdp_reg.
//
// Launch: 1-D grid and block, at most 512 threads per block
// (__launch_bounds__), at least `ngpus` threads per block and blockIdx.x
// within the per-block counter arrays of Signal. The register stores are
// issued by threadIdx.x == 1, so a block needs at least two threads for them
// to happen.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size,
                               uint32_t** curr_hdp_reg, int world_size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  if (threadIdx.x == 1) {
    for (int i = 0; i < world_size; i++) {
      __atomic_store_n(curr_hdp_reg[i], 0x1, __ATOMIC_RELAXED);
    }
  }
  // Rotate the rank order so that slot 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.

  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's slice extends past `part`.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
453
454
455
456
// Byte-array image of a cudaIpcMemHandle_t, used as the key type of an
// ordered std::map of opened IPC handles. The static_asserts check that the
// array has the same size and alignment as the handle type, since handles are
// reinterpreted as IPC_KEY when looked up.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

457
458
459
460
/**
 * Host-side driver for the custom all-reduce kernels of one rank.
 *
 * It keeps the Signal pointers of every rank, a table of registered input
 * buffers (device-resident RankData entries), the IPC handles it has opened,
 * and launches the 1-stage / 2-stage kernels on request.
 *
 * NOTE(review): instances are copyable by default although the destructor
 * closes IPC handles and frees dev_curr_hdp_reg; copies must not be made.
 */
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  // Full NVLink or xGMI connection between GPUs.
  bool fully_connected_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each addresses used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  // Device array with one register address per rank, passed to the *_pcie
  // kernels. Allocated by the constructor only when fully_connected is false
  // (ROCm builds); nullptr otherwise, so the destructor and allreduce_pcie can
  // tell whether it exists.
  uint32_t** dev_curr_hdp_reg = nullptr;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second
   * section is for storing the intermediate results required by some
   * allreduce algos.
   *
   * Note: apart from the small register table in dev_curr_hdp_reg, this class
   * does not own any device memory. Any required buffers are passed in from
   * the constructor.
   *
   * @param signals         one Signal pointer per rank (world_size entries).
   * @param rank_data       device memory used to store RankData entries.
   * @param rank_data_sz    size of rank_data in bytes.
   * @param rank            id of this rank; signals[rank] is its own Signal.
   * @param world_size      number of ranks.
   * @param fully_connected whether all GPUs are directly connected; when
   *                        false (ROCm builds) the register table used by
   *                        allreduce_pcie is created.
   * @throws std::runtime_error if a device's register address cannot be
   *                        queried.
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool fully_connected = true)
      : rank_(rank),
        world_size_(world_size),
        fully_connected_(fully_connected),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
#if defined(USE_ROCM)
    if (!fully_connected) {
      // Collect the register addresses in host memory first and copy them to
      // the device in one go; the device allocation must not be written
      // through directly from host code.
      // NOTE(review): as in the original code, the attribute query is handed
      // a pointer-sized slot through its int* parameter — confirm that HIP
      // writes a full pointer for hipDeviceAttributeHdpMemFlushCntl.
      std::vector<uint32_t*> host_regs(world_size_, nullptr);
      for (int i = 0; i < world_size_; ++i) {
        if (hipDeviceGetAttribute(reinterpret_cast<int*>(&host_regs[i]),
                                  hipDeviceAttributeHdpMemFlushCntl,
                                  i) != hipSuccess)
          throw std::runtime_error(
              "failed to query HDP flush register of device " +
              std::to_string(i));
      }
      CUDACHECK(cudaMalloc((void**)&dev_curr_hdp_reg,
                           world_size_ * sizeof(uint32_t*)));
      CUDACHECK(cudaMemcpy(dev_curr_hdp_reg, host_regs.data(),
                           world_size_ * sizeof(uint32_t*),
                           cudaMemcpyHostToDevice));
    }
#endif
  }

  /**
   * Opens the IPC memory handle stored at `ipc_handle` and returns the mapped
   * device pointer. Each distinct handle is opened only once; later calls
   * with the same handle bytes return the cached pointer.
   */
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  /**
   * Builds the IPC metadata of every buffer recorded during graph capture.
   *
   * @return a pair of (concatenated raw cudaIpcMemHandle_t bytes, one per
   *         buffer; byte offset of each buffer inside its allocation).
   * @throws std::runtime_error if the base address of a buffer cannot be
   *         queried.
   */
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  /**
   * Throws std::runtime_error if fewer than `num` RankData slots remain in the
   * rank data buffer.
   */
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   *
   * @param ptrs one pointer per rank (world_size_ entries); ptrs[rank_] is the
   *             local buffer and becomes the lookup key for allreduce().
   * @throws std::runtime_error if the rank data buffer is full.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have internal reference
  // counting mechanism so overhead should be small.
  //
  // handles[j] holds rank j's concatenated IPC handles and offsets[j][i] the
  // byte offset of buffer i within rank j's allocation, i.e. the values
  // returned by get_graph_buffer_ipc_meta() on each rank. The resulting
  // RankData entries are copied to the device and the list of unregistered
  // graph buffers is cleared.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    const size_t num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce with the *_pcie kernels, assuming input has already
   * been registered (or the stream is being captured).
   *
   * @param stream      stream to launch on.
   * @param input       local input buffer.
   * @param output      local output buffer.
   * @param size        number of T elements; must be a multiple of the pack
   *                    width packed_t<T>::P::size.
   * @param threads     threads per block.
   * @param block_limit upper bound on the number of blocks (<= kMaxBlocks).
   * @throws std::runtime_error on invalid arguments, an unregistered input,
   *         an unsupported world size, a missing register table, or a failed
   *         kernel launch.
   */
  template <typename T>
  void allreduce_pcie(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    validate_request<T>(size, block_limit);
    // The pcie kernels dereference the register table; it only exists when
    // the object was constructed with fully_connected = false.
    if (dev_curr_hdp_reg == nullptr)
      throw std::runtime_error(
          "allreduce_pcie requires the HDP register table, which is only "
          "created when fully_connected is false");

    RankData* ptrs = lookup_rank_data(stream, input);

    size /= packed_t<T>::P::size;
    // Nothing to reduce; a launch with zero blocks would be invalid.
    if (size == 0) return;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size,                 \
                                                 dev_curr_hdp_reg,            \
                                                 world_size_);

#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage_pcie);     \
    } else {                                          \
      if ((world_size_ <= 4 && bytes < 128 * 8192) || \
          (world_size_ <= 8 && bytes < 8 * 8192)) {   \
        KL(ngpus, cross_device_reduce_1stage_pcie);   \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage_pcie);   \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      REDUCE_CASE(16)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
    check_kernel_launch();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are results after careful grid search.
   * Using 36 blocks give the best or close to the best runtime on the devices
   * I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
   * only take a small amount of SMs. Not quite sure the underlying reason,
   * but my guess is that too many SMs will cause contention on NVLink bus.
   *
   * @param stream      stream to launch on.
   * @param input       local input buffer.
   * @param output      local output buffer.
   * @param size        number of T elements; must be a multiple of the pack
   *                    width packed_t<T>::P::size.
   * @param threads     threads per block.
   * @param block_limit upper bound on the number of blocks (<= kMaxBlocks).
   * @throws std::runtime_error on invalid arguments, an unregistered input,
   *         an unsupported world size or topology, or a failed kernel launch.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    validate_request<T>(size, block_limit);

    RankData* ptrs = lookup_rank_data(stream, input);

    size /= packed_t<T>::P::size;
    // Nothing to reduce; a launch with zero blocks would be invalid.
    if (size == 0) return;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                                              \
  case ngpus: {                                                         \
    if (world_size_ == 2) {                                             \
      KL(ngpus, cross_device_reduce_1stage);                            \
    } else if (fully_connected_) {                                      \
      if ((world_size_ <= 4 && bytes < 512 * 1024) ||                   \
          (world_size_ <= 8 && bytes < 256 * 1024)) {                   \
        KL(ngpus, cross_device_reduce_1stage);                          \
      } else {                                                          \
        KL(ngpus, cross_device_reduce_2stage);                          \
      }                                                                 \
    } else {                                                            \
      /* No kernel exists for this topology: fail loudly instead of */  \
      /* returning with `output` untouched.                         */  \
      throw std::runtime_error(                                         \
          "custom allreduce requires fully connected GPUs for more "    \
          "than 2 ranks");                                              \
    }                                                                   \
    break;                                                              \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
    check_kernel_launch();
  }

  // Closes every IPC handle this object opened and releases the register
  // table if one was allocated.
  ~CustomAllreduce() {
    for (const auto& entry : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(entry.second));
    }
    if (dev_curr_hdp_reg != nullptr) cudaFree(dev_curr_hdp_reg);
  }

 private:
  // Argument checks shared by allreduce() and allreduce_pcie().
  template <typename T>
  static void validate_request(int size, int block_limit) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));
  }

  // Returns the device RankData entry to pass to the kernel for `input`.
  // While `stream` is being captured, the next free slot is reserved and the
  // buffer is recorded for later registration; otherwise the buffer must
  // already be registered.
  RankData* lookup_rank_data(cudaStream_t stream, void* input) {
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      RankData* slot = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
      return slot;
    }
    auto it = buffers_.find(input);
    if (it == buffers_.end())
      throw std::runtime_error(
          "buffer address " +
          std::to_string(reinterpret_cast<uint64_t>(input)) +
          " is not registered!");
    return it->second;
  }

  // Surfaces launch-configuration errors, which a <<<...>>> launch does not
  // report by itself.
  static void check_kernel_launch() {
    const cudaError_t launch_status = cudaGetLastError();
    if (launch_status != cudaSuccess)
      throw std::runtime_error(
          std::string("custom allreduce kernel launch failed: ") +
          cudaGetErrorString(launch_status));
  }
};
762

763
/**
 * To inspect PTX/SASS, copy and paste this header file into Compiler Explorer
 * and add a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 * half *, int, int, int);
 */
769
}  // namespace vllm