// custom_all_reduce.cuh
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

8
9
10
11
#if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16;
#endif

12
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
20

21
namespace vllm {
22
23
24
25
26
27
28
29
30
31
// Evaluates the CUDA runtime call `cmd` exactly once. On any result other
// than cudaSuccess it prints the source location and the CUDA error string,
// then terminates the process with EXIT_FAILURE.
//
// The argument is parenthesized and the temporary carries a macro-specific
// name so that an expression mentioning a caller variable called `e` is not
// captured by the macro's own local.
#define CUDACHECK(cmd)                                              \
  do {                                                              \
    const cudaError_t cudacheck_err_ = (cmd);                       \
    if (cudacheck_err_ != cudaSuccess) {                            \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(cudacheck_err_));                   \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

32
// Maximal number of blocks in allreduce kernel. It also sizes the per-block
// counter arrays in Signal, and allreduce() rejects any block limit above it.
constexpr int kMaxBlocks = 36;
34
35
36
37
38
39
40
41
42
43
44

// Default number of blocks in allreduce kernel, and the driver pointer
// attribute used to look up the base address of an allocation.
//
// Both are `const` so they have internal linkage: this is a header, and a
// non-const namespace-scope variable defined here would be emitted by every
// translation unit that includes it (multiple-definition link errors).
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
const CUpointer_attribute rangeStartAddrAttr =
    CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
const hipPointer_attribute rangeStartAddrAttr =
    HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif

45
46
47
// Integer type of the synchronization counters stored in Signal.
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
48
49
50
51
52
53
54

// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for a peer GPU block to arrive
// at the second sync point while the current GPU block hasn't passed the
// first sync point. Thus, the peer GPU may write counter+1 while the current
// GPU is busy waiting for counter. We use alternating counter arrays to avoid
// this possibility.
struct Signal {
  // Counters for the barrier at the start of an operation, indexed as
  // [block][writing rank].
  alignas(128) FlagType start[kMaxBlocks][8];
  // Counters for the barrier at the end of an operation, same indexing.
  alignas(128) FlagType end[kMaxBlocks][8];
  // Per-block value of the most recently completed barrier; each barrier
  // increments it by one.
  alignas(128) FlagType _flag[kMaxBlocks];
};

61
// One buffer pointer per rank (up to 8 ranks), indexed by rank. The kernels
// read ptrs[i] as rank i's input buffer.
struct __align__(16) RankData {
  const void* ptrs[8];
};
64

65
66
67
// One Signal pointer per rank (up to 8 ranks), indexed by rank. Passed by
// value to the allreduce kernels.
struct __align__(16) RankSignals {
  Signal* signals[8];
};
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

// like std::array, but aligned
// The alignment is alignof(T) * sz, i.e. the alignment of the whole pack
// rather than of a single element.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  // Element type, exposed for generic code such as downcast().
  using type = T;
  // Number of elements in the pack.
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
// Both packs hold 16 / sizeof(T) elements, so P occupies 16 bytes.
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// Scalar conversions between an element type and float.

// Widens a half-precision value to float.
DINLINE float upcast_s(half val) {
  const float widened = __half2float(val);
  return widened;
}

// Narrows a float to element type T. Only the explicit specializations are
// defined.
template <typename T>
DINLINE T downcast_s(float val);

// half specialization of downcast_s.
template <>
DINLINE half downcast_s(float val) {
  const half narrowed = __float2half(val);
  return narrowed;
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly

// Adds b to a in place using the half-precision add intrinsic; returns a.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}

// Adds b to a in place; returns a.
DINLINE float& assign_add(float& a, float b) { return a += b; }
107
108
109
110
111
112
113

// bfloat16 counterparts of the scalar helpers. They are compiled for the host
// pass (__CUDA_ARCH__ undefined) and for device architectures >= 800 only.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// Widens a bfloat16 value to float.
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }

// bfloat16 specialization of downcast_s.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}

// Adds b to a in place using the add intrinsic; returns a.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

// Element-wise a += b over two packs of N elements, using the scalar
// assign_add overload for T. Returns a reference to a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

// Converts a pack of N elements of type T into a pack of N floats. A float
// pack is returned unchanged; any other element type is widened one element
// at a time with upcast_s.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}

// Converts a float pack into the packed output type O (an array_t). When O
// already holds floats the input is returned unchanged; otherwise every
// element is narrowed with downcast_s<O::type>.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}

157
158
#if !defined(USE_ROCM)

159
// Stores `flag` to `flag_addr` in global memory with release ordering.
// On architectures >= 700 this is a single system-scope release store; on
// older ones it is a system-scope memory barrier followed by a volatile
// store.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
  #else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
  #endif
}

// Loads and returns the flag at `flag_addr` in global memory with acquire
// ordering. On architectures >= 700 this is a single system-scope acquire
// load; on older ones it is a volatile load followed by a memory barrier.
//
// NOTE(review): the fallback path fences with `membar.gl` while the matching
// store fallback in st_flag_release uses `membar.sys` — confirm the narrower
// scope is intended for flags written by peer devices.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  #else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
               : "=r"(flag)
               : "l"(flag_addr));
  #endif
  return flag;
}

// Stores `flag` to `flag_addr` in global memory with a volatile store and no
// additional fence.
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

// Loads and returns the flag at `flag_addr` in global memory with a volatile
// load and no additional fence.
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

195
196
197
198
199
200
201
202
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
// ngpus:   number of participating ranks; the first ngpus threads of the
//          block each handle one peer.
// sg:      Signal pointers of all ranks.
// self_sg: this rank's own Signal.
// rank:    index of this rank.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  // Every thread reads the block's last flag and targets the next value.
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value
    // from peer.
    st_flag_volatile(peer_counter_ptr, flag);
    while (ld_flag_volatile(self_counter_ptr) != flag);
  }
  // Hold the rest of the block until the polling threads have seen all peers.
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
//
// ngpus:      number of participating ranks; the first ngpus threads of the
//             block each handle one peer.
// final_sync: when true, plain volatile flag accesses are used and the
//             trailing block barrier is skipped; when false, release/acquire
//             flag accesses are used and the block is synchronized again
//             before returning.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  // All threads of the block must have finished their preceding work before
  // any thread signals the peers.
  __syncthreads();
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    if constexpr (!final_sync) {
      st_flag_release(peer_counter_ptr, flag);
      while (ld_flag_acquire(self_counter_ptr) != flag);
    } else {
      st_flag_volatile(peer_counter_ptr, flag);
      while (ld_flag_volatile(self_counter_ptr) != flag);
    }
  }
  if constexpr (!final_sync) __syncthreads();

  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

243
244
245
246
247
248
249
250
251
#else

// ROCm variant of the barrier used at the start of the all reduce kernel.
// The first ngpus threads of the block each publish the next flag value to
// one rank's `start` counter and then spin on this rank's matching counter.
// All flag accesses are relaxed atomics.
//
// NOTE(review): the wait condition is `< flag`, whereas the CUDA variant
// spins on `!= flag`. If `flag` ever wraps around to a small value, `< flag`
// would stop waiting immediately — confirm this is acceptable.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    // __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
    //                         flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                     __ATOMIC_RELAXED);
    // wait until this rank's counter for that peer has reached flag
    // while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
    //                               __ATOMIC_RELAXED,
    //                               __MEMORY_SCOPE_DEVICE) < flag);
    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                           __ATOMIC_RELAXED) < flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// ROCm variant of the barrier used as the second or final synchronization of
// the all reduce kernel. The first ngpus threads of the block each publish
// the next flag value to one rank's `end` counter and then spin on this
// rank's matching counter.
//
// final_sync: when true the flag store/load are relaxed and the trailing
//             block barrier is skipped; when false they are release/acquire
//             and the block is synchronized again before returning.
//
// NOTE(review): the wait condition is `< flag`, whereas the CUDA variant
// spins on `!= flag`; see the wrap-around concern on `< flag` comparisons —
// confirm this is acceptable.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  // All threads of the block must have finished their preceding work before
  // any thread signals the peers.
  __syncthreads();
  FlagType flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    // __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
    //                         flag,
    //                         final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
    //                         __MEMORY_SCOPE_SYSTEM);
    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
    // wait until this rank's counter for that peer has reached flag
    // while (
    //     __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
    //                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
    //                            __MEMORY_SCOPE_DEVICE) < flag);
    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
           flag);
  }
  if constexpr (!final_sync) __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

#endif

297
// Sums the idx-th pack across the ngpus buffers in `ptrs`.
// The packs are widened to the accumulator type A, added in order
// ptrs[0], ptrs[1], ..., ptrs[ngpus - 1], and the sum is narrowed back to P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

// One-stage all reduce: after a start barrier every rank reads the inputs of
// all ranks and writes the full sum to its own `result`, then passes a final
// barrier.
//
// _dp:     device pointer to the per-rank input pointers.
// sg:      Signal pointers of all ranks.
// self_sg: this rank's own Signal.
// result:  output buffer, written as packed elements P.
// rank:    index of this rank.
// size:    number of packed elements P to reduce (not scalar elements).
//
// Launch: at most 512 threads per block (__launch_bounds__). The barriers
// index the counters by blockIdx.x, which therefore must stay below
// kMaxBlocks.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  barrier_at_start<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}

// Returns the address immediately after the Signal object `sg` points to,
// typed as P*. The two-stage kernel uses that region as a rank's temporary
// buffer.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(sg + 1);
}

// Two-stage all reduce (reduce-scatter followed by all-gather).
//
// The `size` packed elements are split into ngpus parts of size / ngpus
// elements each; the last rank additionally takes the remainder. In stage 1
// each rank sums its own part across all ranks into the temporary buffer
// that follows its Signal (see get_tmp_buf). In stage 2 each rank copies
// every part from the owning rank's temporary buffer into `result`.
//
// _dp:     device pointer to the per-rank input pointers.
// sg:      Signal pointers of all ranks.
// self_sg: this rank's own Signal.
// result:  output buffer, written as packed elements P.
// rank:    index of this rank.
// size:    number of packed elements P to reduce (not scalar elements).
//
// Launch: at most 512 threads per block (__launch_bounds__). The barriers
// index the counters by blockIdx.x, which therefore must stay below
// kMaxBlocks.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  // The last rank's part also holds the remainder, so it is the largest.
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's part extends past `part` elements.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
377
378
379
380
// Byte-array copy of a cudaIpcMemHandle_t, used as the key type of the map of
// opened IPC handles. The assertions check that it has the same size and
// alignment as the handle type itself.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

381
382
383
384
// Host-side driver for the custom all reduce kernels. It keeps the Signal
// pointers of all ranks, a device array of RankData entries describing the
// registered buffers, and the IPC handles it has opened.
class CustomAllreduce {
 public:
  // Capacity of the per-rank pointer arrays in RankSignals / RankData.
  static constexpr int kMaxRanks =
      static_cast<int>(sizeof(RankSignals::signals) / sizeof(Signal*));

  int rank_;
  int world_size_;
  // Full NVLink or xGMI connection between GPUs.
  bool fully_connected_;

  // Signal pointers of all ranks; passed by value to the kernels.
  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for cuda graph purposes.
  // For cuda graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each addresses used during cuda
  // graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  // the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each of the buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second
   * section is for storing the intermediate results required by some
   * allreduce algos.
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   *
   * Throws std::runtime_error when world_size does not fit the per-rank
   * pointer arrays or rank is outside [0, world_size).
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool fully_connected = true)
      : rank_(rank),
        world_size_(world_size),
        fully_connected_(fully_connected),
        sg_{},  // unused slots stay null instead of indeterminate
        self_sg_(validated_self_signal(signals, rank, world_size)),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  /**
   * Opens the IPC memory handle stored at `ipc_handle` and returns the
   * resulting pointer. A handle that was opened before is served from
   * ipc_handles_ instead of being opened again.
   */
  char* open_ipc_handle(const void* ipc_handle) {
    // Copy the bytes out instead of casting the pointer, so constness is kept
    // and no alignment is assumed for the caller's storage.
    IPC_KEY key;
    std::memcpy(key.data(), ipc_handle, sizeof(IPC_KEY));
    auto [it, new_handle] = ipc_handles_.insert({key, nullptr});
    if (new_handle) {
      cudaIpcMemHandle_t mem_handle;
      std::memcpy(&mem_handle, ipc_handle, sizeof(mem_handle));
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, mem_handle,
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  /**
   * For every buffer recorded during graph capture, returns the IPC handle
   * of the allocation that contains it and the buffer's byte offset from
   * the start of that allocation. The handles are concatenated into a
   * single string, sizeof(cudaIpcMemHandle_t) bytes each.
   */
  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    const size_t num_buffers = graph_unreg_buffers_.size();
    constexpr size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(std::move(handles), std::move(offsets));
  }

  /**
   * Throws std::runtime_error if `num` more RankData entries do not fit in
   * the remaining rank data buffer.
   */
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   *
   * `ptrs` holds one pointer per rank. They are copied into the next free
   * RankData slot on the device, and the slot is indexed under this rank's
   * own pointer so that allreduce() can find it.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data{};
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have internal reference
  // counting mechanism so overhead should be small.
  //
  // handles[j] / offsets[j] are the values produced by
  // get_graph_buffer_ipc_meta() on rank j. Throws std::runtime_error when
  // the metadata does not cover every rank and every captured buffer.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    const size_t num_buffers = graph_unreg_buffers_.size();
    constexpr size_t handle_sz = sizeof(cudaIpcMemHandle_t);
    check_rank_data_capacity(num_buffers);

    // Reject short metadata up front instead of reading out of bounds below.
    const size_t world = static_cast<size_t>(world_size_);
    if (handles.size() < world || offsets.size() < world)
      throw std::runtime_error(
          "graph buffer IPC metadata must cover all " +
          std::to_string(world_size_) + " ranks");
    for (int j = 0; j < world_size_; j++) {
      if (j == rank_) continue;
      if (handles[j].size() < num_buffers * handle_sz ||
          offsets[j].size() < num_buffers)
        throw std::runtime_error(
            "graph buffer IPC metadata from rank " + std::to_string(j) +
            " does not cover all " + std::to_string(num_buffers) +
            " captured buffers");
    }

    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * handle_sz]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * Block and grid default configs are results after careful grid search.
   * Using 36 blocks give the best or close to the best runtime on the devices
   * I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
   * only take a small amount of SMs. Not quite sure the underlying reason,
   * but my guess is that too many SMs will cause contention on NVLink bus.
   *
   * stream:      stream the kernel is launched on. While the stream is being
   *              captured, `input` is recorded in graph_unreg_buffers_
   *              instead of being looked up in buffers_.
   * input:       this rank's input buffer.
   * output:      this rank's output buffer.
   * size:        number of elements of type T; must be a multiple of the
   *              pack size.
   * threads:     threads per block, in [1, 512].
   * block_limit: upper bound on the number of blocks, at most kMaxBlocks.
   *
   * The environment variable VLLM_CUSTOM_ALLREDUCE_ALGO is read on every
   * call and can force the one-stage ("1stage"/"oneshot") or two-stage
   * ("2stage"/"twoshot") kernel.
   *
   * Throws std::runtime_error on invalid arguments, an unsupported world
   * size, an invalid environment value, an unregistered input buffer, a
   * topology for which no algorithm is available, or a failed launch.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    // Upper bound given by the kernels' __launch_bounds__.
    constexpr int kMaxThreads = 512;
    constexpr int d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));
    if (threads < 1 || threads > kMaxThreads)
      throw std::runtime_error("threads per block must be in [1, " +
                               std::to_string(kMaxThreads) + "]. Got " +
                               std::to_string(threads));
    if (world_size_ != 2 && world_size_ != 4 && world_size_ != 6 &&
        world_size_ != 8)
      throw std::runtime_error(
          "custom allreduce only supports num gpus in (2,4,6,8). Actual "
          "num "
          "gpus = " +
          std::to_string(world_size_));

    // Optional algorithm override; read on every call.
    const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO");
    bool force_1stage = false;
    bool force_2stage = false;
    if (env_algo != nullptr) {
      if (std::strcmp(env_algo, "1stage") == 0 ||
          std::strcmp(env_algo, "oneshot") == 0) {
        force_1stage = true;
      } else if (std::strcmp(env_algo, "2stage") == 0 ||
                 std::strcmp(env_algo, "twoshot") == 0) {
        force_2stage = true;
      } else {
        throw std::runtime_error(
            "Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) +
            ". Valid values: 1stage, oneshot, 2stage, twoshot");
      }
    }

    // Without an override, more than two GPUs are only handled when fully
    // connected. Fail loudly rather than return without launching anything,
    // which would leave `output` unwritten.
    if (!force_1stage && !force_2stage && world_size_ != 2 &&
        !fully_connected_)
      throw std::runtime_error(
          "custom allreduce requires fully connected GPUs when num gpus > 2");

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    // From here on `size` counts packed elements.
    size /= d;
    // Nothing to reduce; a launch with zero blocks would be rejected.
    if (size == 0) return;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);

    // Pick the kernel: an explicit override wins, two GPUs always use the
    // one-stage kernel, and otherwise the choice depends on the message size.
    bool use_1stage;
    if (force_1stage) {
      use_1stage = true;
    } else if (force_2stage) {
      use_1stage = false;
    } else if (world_size_ == 2) {
      use_1stage = true;
    } else {
      use_1stage = (world_size_ <= 4 && bytes < 512 * 1024) ||
                   (world_size_ <= 8 && bytes < 256 * 1024);
    }

#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                   \
  case ngpus: {                              \
    if (use_1stage) {                        \
      KL(ngpus, cross_device_reduce_1stage); \
    } else {                                 \
      KL(ngpus, cross_device_reduce_2stage); \
    }                                        \
    break;                                   \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        // Unreachable: the world size was validated above.
        break;
    }
#undef REDUCE_CASE
#undef KL

    // Kernel launches do not report errors directly; surface a rejected
    // launch configuration here.
    const cudaError_t launch_err = cudaGetLastError();
    if (launch_err != cudaSuccess)
      throw std::runtime_error(
          std::string("custom allreduce kernel launch failed: ") +
          cudaGetErrorString(launch_err));
  }

  // Closes every IPC handle opened through open_ipc_handle().
  ~CustomAllreduce() {
    for (auto& [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }

 private:
  // Checks world_size and rank against the fixed-size per-rank arrays before
  // `signals` is indexed, then returns this rank's Signal pointer.
  static Signal* validated_self_signal(Signal** signals, int rank,
                                       int world_size) {
    if (world_size < 1 || world_size > kMaxRanks)
      throw std::runtime_error("custom allreduce supports at most " +
                               std::to_string(kMaxRanks) +
                               " ranks. Got world size " +
                               std::to_string(world_size));
    if (rank < 0 || rank >= world_size)
      throw std::runtime_error("rank " + std::to_string(rank) +
                               " is out of range for world size " +
                               std::to_string(world_size));
    return signals[rank];
  }
};
634

635
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and
 * add a template instantiation:
 *   template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 *                                                        half *, int, int, int);
 */
641
}  // namespace vllm