custom_all_reduce.cuh 21.2 KB
Newer Older
1
2
3
4
5
6
7
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

8
9
10
11
#if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16;
#endif

12
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

19
namespace vllm {
20
21
22
23
24
25
26
27
28
29
// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the source location and the runtime's error string and
// terminates the process with EXIT_FAILURE.
// The argument is parenthesized and the temporary has a macro-private name,
// so an expression such as CUDACHECK(f(e)) is not captured by the local.
#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t cudacheck_status_ = (cmd);                          \
    if (cudacheck_status_ != cudaSuccess) {                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(cudacheck_status_));                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

30
// Maximal number of blocks in allreduce kernel.
31
// This bound also sizes the per-block counter arrays in Signal, and
// CustomAllreduce::allreduce rejects any block_limit above it.
constexpr int kMaxBlocks = 36;
32
33
34
35
36
37
38
39
40
41
42

// Default number of blocks in allreduce kernel, plus the pointer attribute
// used to look up the base address of an allocation.
// Both are declared const: a non-const namespace-scope variable defined in a
// header has external linkage and produces a multiple-definition link error
// as soon as the header is included from more than one translation unit.
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
const CUpointer_attribute rangeStartAddrAttr =
    CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
const hipPointer_attribute rangeStartAddrAttr =
    HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif

43
44
45
// Type of the barrier counters stored in Signal.
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
46
47
48
49
50
51
52

// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for a peer GPU block to arrive
// at the second sync point while the current GPU block hasn't passed the
// first sync point yet. Thus, the peer GPU may write counter+1 while the
// current GPU is still busy-waiting for counter. We use alternating counter
// arrays to avoid this possibility.
53
struct Signal {
  // Counters for the barrier that opens an operation. Entry [block][r] is
  // written by rank r for thread block `block`.
  alignas(128) FlagType start[kMaxBlocks][8];
  // Counters for the barrier that closes an operation; indexed like `start`.
  alignas(128) FlagType end[kMaxBlocks][8];
  // Incremental flag, one slot per thread block (indexed by blockIdx.x).
  alignas(128) FlagType _flag[kMaxBlocks];
};

59
// One data pointer per rank (slot r belongs to rank r). The table has room
// for 8 ranks.
struct __align__(16) RankData {
  const void* ptrs[8];
};
62

63
64
65
// One Signal pointer per rank (slot r belongs to rank r). The table has room
// for 8 ranks and is passed to the allreduce kernels by value.
struct __align__(16) RankSignals {
  Signal* signals[8];
};
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

// Fixed-size aggregate in the spirit of std::array, except that the whole
// object is aligned to alignof(T) * sz.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  // element type and element count, exposed for generic code
  using type = T;
  static constexpr int size = sz;

  T data[sz];
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
// Both aliases hold 16 / sizeof(T) elements, i.e. P occupies 16 bytes.
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  // (same element count as P, but with float elements)
  using A = array_t<float, 16 / sizeof(T)>;
};

// Shorthand qualifier for device-only helpers that are force-inlined.
#define DINLINE __device__ __forceinline__

// scalar cast functions
// Widens a single half value to float.
DINLINE float upcast_s(half val) {
  return __half2float(val);
}

// Narrows a float to the scalar type T. Only explicit specializations are
// defined; the primary template is declared but has no body.
template <typename T>
DINLINE T downcast_s(float val);

// half specialization
template <>
DINLINE half downcast_s<half>(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
100
// In-place addition for half via the __hadd intrinsic. Returns a reference
// to the updated left operand.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
104
// In-place addition for float. Returns a reference to the updated left
// operand.
DINLINE float& assign_add(float& a, float b) {
  a += b;
  return a;
}
105
106
107
108
109
110
111

// bfloat16 overloads of the scalar helpers. They are compiled in passes
// where __CUDA_ARCH__ is undefined and in device passes targeting
// architecture 800 or newer.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// Widens a single bfloat16 value to float.
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
// Narrows a float to bfloat16.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
// In-place bfloat16 addition via __hadd; returns the updated left operand.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

// Element-wise a += b over all N elements, using the scalar assign_add
// overload for T. Returns a reference to `a`.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

// Converts every element of `val` to float. When T is already float the
// argument is returned as-is.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}

// Converts a float accumulator array to the packed type O, element by
// element. When O already holds floats the argument is returned as-is.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}

155
156
#if !defined(USE_ROCM)

157
// Stores `flag` to `flag_addr` with release ordering at system scope.
// Architecture 700 and newer use st.release.sys; older targets issue
// membar.sys followed by a volatile store.
// The "memory" clobber tells the compiler that the statement orders memory,
// so it does not move surrounding loads/stores across the asm.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;"
               :
               : "r"(flag), "l"(flag_addr)
               : "memory");
  #else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;"
               :
               : "r"(flag), "l"(flag_addr)
               : "memory");
  #endif
}

// Loads the flag at `flag_addr` with acquire ordering and returns it.
// Architecture 700 and newer use ld.acquire.sys; older targets issue a
// volatile load followed by membar.gl.
// The "memory" clobber tells the compiler that the statement orders memory,
// so it does not move surrounding loads/stores across the asm.
// NOTE(review): the pre-700 store path fences with membar.sys while this
// path uses membar.gl — confirm that device scope is sufficient here.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
  #else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
  #endif
  return flag;
}

// Stores `flag` to `flag_addr` with a volatile global store. No fence is
// issued, so this adds no ordering for other memory accesses.
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

// Loads and returns the flag at `flag_addr` with a volatile global load. No
// fence is issued, so this adds no ordering for other memory accesses.
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

193
194
195
196
197
198
199
200
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
// Threads 0..ngpus-1 of each block perform the handshake (thread t talks to
// rank t), so the block must have at least ngpus threads. Must be called by
// every thread of the block because it contains __syncthreads().
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  // Value every rank is expected to publish for this barrier.
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value
    // from peer.
    st_flag_volatile(peer_counter_ptr, flag);
    while (ld_flag_volatile(self_counter_ptr) != flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
//
// Threads 0..ngpus-1 of each block perform the handshake (thread t talks to
// rank t), so the block must have at least ngpus threads. Must be called by
// every thread of the block because it contains __syncthreads().
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  // Value every rank is expected to publish for this barrier.
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    if constexpr (!final_sync) {
      // Release/acquire pair: work done before the barrier must be visible
      // to the peers after it.
      st_flag_release(peer_counter_ptr, flag);
      while (ld_flag_acquire(self_counter_ptr) != flag);
    } else {
      st_flag_volatile(peer_counter_ptr, flag);
      while (ld_flag_volatile(self_counter_ptr) != flag);
    }
  }
  // Only a non-final barrier is followed by more work in the block, so only
  // then do the remaining threads have to wait for the handshake.
  if constexpr (!final_sync) __syncthreads();

  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#else

// ROCm variant of the barrier that opens an all reduce kernel. All accesses
// are relaxed, so no visibility is provided for prior memory accesses.
//
// Threads 0..ngpus-1 of each block perform the handshake (thread t talks to
// rank t), so the block must have at least ngpus threads. Must be called by
// every thread of the block because it contains __syncthreads().
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
                            flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until the counter written by the peer has reached flag. The
    // difference is interpreted as a signed 32-bit value so that the
    // comparison remains correct when the counter wraps around; a plain
    // `counter < flag` stops waiting as soon as flag overflows to a small
    // value.
    while (static_cast<int32_t>(
               __scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                                      __ATOMIC_RELAXED,
                                      __MEMORY_SCOPE_DEVICE) -
               flag) < 0);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// ROCm variant of the second or final barrier of an all reduce kernel. A
// non-final barrier uses a release store and acquire loads; the final one
// uses relaxed accesses.
//
// Threads 0..ngpus-1 of each block perform the handshake (thread t talks to
// rank t), so the block must have at least ngpus threads. Must be called by
// every thread of the block because it contains __syncthreads().
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
                            flag,
                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                            __MEMORY_SCOPE_SYSTEM);
    // wait until the counter written by the peer has reached flag. The
    // difference is interpreted as a signed 32-bit value so that the
    // comparison remains correct when the counter wraps around; a plain
    // `counter < flag` stops waiting as soon as flag overflows to a small
    // value.
    while (static_cast<int32_t>(
               __scoped_atomic_load_n(
                   &self_sg->end[blockIdx.x][threadIdx.x],
                   final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                   __MEMORY_SCOPE_DEVICE) -
               flag) < 0);
  }
  if constexpr (!final_sync) __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

#endif

286
// Sums element `idx` of the ngpus input arrays. Each operand is converted to
// the accumulator type A, the sum is taken in order ptrs[0], ptrs[1], ...,
// and the result is converted back to the packed type P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

// One-stage all reduce: every rank reads element idx from all ngpus inputs
// and writes the sum to its own `result`.
//
// Launch layout: 1-D grid and 1-D block of at most 512 threads (launch
// bounds). gridDim.x must not exceed kMaxBlocks because the barrier counters
// are indexed by blockIdx.x, and blockDim.x must be at least ngpus because
// threads 0..ngpus-1 perform the barrier handshake.
// `size` is the number of packed elements (packed_t<T>::P), not scalars.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  barrier_at_start<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}

// Returns the address immediately after the Signal object at `sg`, viewed as
// an array of P. This is where the two-stage kernel keeps its intermediate
// results.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return reinterpret_cast<P*>(sg + 1);
}

// Two-stage all reduce. Stage 1 (reduce scatter): each rank sums its own
// slice of the input and stores it in the temporary buffer that follows its
// Signal. Stage 2 (allgather): each rank copies every slice from the
// temporary buffers of all ranks into `result`.
//
// Launch layout: 1-D grid and 1-D block of at most 512 threads (launch
// bounds). gridDim.x must not exceed kMaxBlocks because the barrier counters
// are indexed by blockIdx.x, and blockDim.x must be at least ngpus because
// threads 0..ngpus-1 perform the barrier handshake.
// `size` is the number of packed elements (packed_t<T>::P), not scalars.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // Slice owned by this rank; the last rank also takes the remainder.
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Rotate the rank order so that slot 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.

  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's slice extends beyond `part` elements.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
366
367
368
369
// Byte-array image of a cudaIpcMemHandle_t, used as the key type of
// CustomAllreduce::ipc_handles_. The assertions check that size and
// alignment match the handle type.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

370
371
372
373
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  // Full NVLink or xGMI connection between GPUs.
  bool fully_connected_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each addresses used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second
   * section is for storing the intermediate results required by some
   * allreduce algos.
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   *
   * @param signals          one Signal pointer per rank (world_size entries)
   * @param rank_data        buffer that receives the RankData records
   * @param rank_data_sz     size of rank_data in bytes
   * @param rank             index of this rank, in [0, world_size)
   * @param world_size       number of ranks; limited by the capacity of
   *                         RankSignals::signals
   * @param fully_connected  see fully_connected_
   * @throws std::runtime_error if rank or world_size is out of range
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool fully_connected = true)
      : rank_(rank),
        world_size_(world_size),
        fully_connected_(fully_connected),
        sg_{},
        self_sg_(nullptr),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    // Reject values that would index outside the fixed-size pointer tables
    // before anything is read from `signals`.
    constexpr int max_ranks =
        static_cast<int>(sizeof(RankSignals::signals) / sizeof(Signal*));
    if (world_size_ < 1 || world_size_ > max_ranks || rank_ < 0 ||
        rank_ >= world_size_)
      throw std::runtime_error(
          "invalid rank " + std::to_string(rank_) + " for world size " +
          std::to_string(world_size_) + " (at most " +
          std::to_string(max_ranks) + " ranks are supported)");
    self_sg_ = signals[rank_];
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  /**
   * Opens the IPC memory handle stored at `ipc_handle` and returns the mapped
   * pointer. Every distinct handle is opened only once; later calls with the
   * same handle bytes return the cached pointer.
   */
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] = ipc_handles_.insert(
        {*(static_cast<const IPC_KEY*>(ipc_handle)), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  /**
   * Builds the IPC metadata for every buffer recorded during graph capture.
   *
   * @return a pair of (concatenated raw IPC handles, one handle per buffer;
   *         byte offset of each buffer from the base of its allocation)
   * @throws std::runtime_error if the allocation base cannot be queried
   */
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(std::move(handles), std::move(offsets));
  }

  /**
   * Throws std::runtime_error if fewer than `num` RankData slots remain.
   * The remaining capacity is computed first so that no pointer beyond the
   * end of the buffer is ever formed.
   */
  void check_rank_data_capacity(size_t num = 1) {
    const size_t remaining =
        static_cast<size_t>(d_rank_data_end_ - d_rank_data_base_);
    if (num > remaining)
      throw std::runtime_error("Rank data buffer is overflowed by " +
                               std::to_string(num - remaining));
  }

  /**
   * Register already-shared IPC pointers.
   *
   * `ptrs` holds one pointer per rank (world_size_ entries). The record is
   * copied to the next free RankData slot and indexed by this rank's own
   * pointer.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    // Zero the slots beyond world_size_ so no indeterminate bytes are copied.
    RankData data{};
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have internal reference
  // counting mechanism so overhead should be small.
  //
  // handles[j] / offsets[j] are the values rank j returned from
  // get_graph_buffer_ipc_meta. Throws std::runtime_error if the gathered
  // metadata does not cover every rank and every captured buffer.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    const size_t num_ranks = static_cast<size_t>(world_size_);
    check_rank_data_capacity(num_buffers);
    if (handles.size() < num_ranks || offsets.size() < num_ranks)
      throw std::runtime_error(
          "graph buffer metadata does not cover all ranks");
    for (size_t j = 0; j < num_ranks; j++) {
      if (static_cast<int>(j) == rank_) continue;
      if (handles[j].size() < num_buffers * handle_sz ||
          offsets[j].size() < num_buffers)
        throw std::runtime_error(
            "graph buffer metadata from rank " + std::to_string(j) +
            " does not cover all captured buffers");
    }
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * handle_sz]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are results after careful grid search.
   * Using 36 blocks give the best or close to the best runtime on the devices
   * I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
   * only take a small amount of SMs. Not quite sure the underlying reason,
   * but my guess is that too many SMs will cause contention on NVLink bus.
   *
   * @param stream       stream the kernel is launched on; if it is being
   *                     captured, `input` is recorded for later registration
   * @param input        buffer to reduce
   * @param output       buffer that receives the result
   * @param size         number of T elements; must be a multiple of the
   *                     packed width (16 / sizeof(T)). A size of 0 is a no-op.
   * @param threads      threads per block
   * @param block_limit  upper bound on the grid size (at most kMaxBlocks)
   * @throws std::runtime_error on invalid size or block_limit, unregistered
   *         input, unsupported world size, or more than two GPUs that are
   *         not fully connected
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));
    // Nothing to reduce. Launching would request a grid of zero blocks,
    // which is an invalid launch configuration.
    if (size == 0) return;

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      // Peer pointers are not known yet; reserve the slot they will be
      // written to by register_graph_buffers.
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    // From here on `size` counts packed elements.
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                                              \
  case ngpus: {                                                         \
    if (world_size_ == 2) {                                             \
      KL(ngpus, cross_device_reduce_1stage);                            \
    } else if (fully_connected_) {                                      \
      if ((world_size_ <= 4 && bytes < 512 * 1024) ||                   \
          (world_size_ <= 8 && bytes < 256 * 1024)) {                   \
        KL(ngpus, cross_device_reduce_1stage);                          \
      } else {                                                          \
        KL(ngpus, cross_device_reduce_2stage);                          \
      }                                                                 \
    } else {                                                            \
      throw std::runtime_error(                                         \
          "custom allreduce with more than 2 gpus requires a fully "    \
          "connected topology");                                        \
    }                                                                   \
    break;                                                              \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
    // A kernel launch does not return a status; pick up launch-configuration
    // errors here instead of letting them surface at an unrelated call.
    CUDACHECK(cudaGetLastError());
  }

  // Closes every IPC mapping opened through open_ipc_handle.
  ~CustomAllreduce() {
    for (const auto& entry : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(entry.second));
    }
  }
};
598

599
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and
 * add a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 * half *, int, int, int);
 */
605
}  // namespace vllm