custom_all_reduce.cuh 17.7 KB
Newer Older
1
2
3
4
5
6
7
8
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// Evaluates a CUDA runtime call exactly once. On any result other than
// cudaSuccess it prints the file, line and CUDA error string to stderr and
// terminates the process with EXIT_FAILURE.
// The argument is parenthesized and the local uses a reserved-looking name so
// that arbitrary expressions (including ones mentioning a variable `e`) are
// safe to pass.
#define CUDACHECK(cmd)                                                       \
  do {                                                                       \
    cudaError_t cudacheck_err_ = (cmd);                                      \
    if (cudacheck_err_ != cudaSuccess) {                                     \
      fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
              cudaGetErrorString(cudacheck_err_));                           \
      exit(EXIT_FAILURE);                                                    \
    }                                                                        \
  } while (0)

namespace vllm {

27
28
29
30
// Upper bound on the number of thread blocks per launch: it sizes the
// per-block counter arrays in Signal, and allreduce() rejects any block_limit
// above it.
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
31
// Per-rank synchronization state used by multi_gpu_barrier. Both arrays are
// indexed by block and then by one of up to 8 slots.
struct Signal {
  // Barrier counters incremented by the owning GPU itself, one slot per
  // (block, thread) pair taking part in the barrier.
  alignas(128) FlagType self_counter[kMaxBlocks][8];
  // Two sets of peer counters are needed for two syncs. The reason is that
  // it's possible for a peer GPU block to arrive at the second sync point
  // while the current GPU block hasn't passed the first sync point. Thus, the
  // peer GPU may write counter+1 while the current GPU is busy waiting for
  // counter. We use alternating counter arrays to avoid this possibility.
  alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};

41
// Table of up to 8 buffer pointers; the register_* methods of CustomAllreduce
// fill slot i with rank i's address for one registered buffer.
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
42

43
// Table of up to 8 Signal pointers; the CustomAllreduce constructor stores
// rank i's Signal in slot i.
struct __align__(16) RankSignals { Signal* signals[8]; };
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

// Fixed-size array in the spirit of std::array, but with its alignment forced
// to alignof(T) * sz.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  // Element type and element count, exposed for generic code.
  using type = T;
  static constexpr int size = sz;

  T data[sz];
};

// Maps an element type T to 16-byte-wide vector types to maximize memory
// efficiency.
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (A)ccumulator type for reduction: same lane count as P, float lanes
  using A = array_t<float, 16 / sizeof(T)>;
  // the (P)acked type for load/store: 16 bytes worth of T
  using P = array_t<T, 16 / sizeof(T)>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions

// Widens a single half value to float.
DINLINE float upcast_s(half val) {
  const float widened = __half2float(val);
  return widened;
}

// Narrows a float to T. Only the explicit specializations are defined.
template <typename T>
DINLINE T downcast_s(float val);

// Specialization for half.
template <>
DINLINE half downcast_s<half>(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
78
// In-place half addition through the __hadd intrinsic; returns a reference to
// the updated left operand.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
82
// float overload: plain in-place addition; returns a reference to the updated
// left operand.
DINLINE float& assign_add(float& a, float b) {
  a += b;
  return a;
}
83
84
85
86
87
88
89

// nv_bfloat16 overloads of the scalar helpers. They are compiled when
// __CUDA_ARCH__ is undefined and for device architectures >= 800.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// Widens a single bfloat16 value to float.
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }

// Narrows a float to bfloat16.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}

// In-place bfloat16 addition through the __hadd intrinsic; returns a reference
// to the updated left operand.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

// Lane-wise a += b over a packed value, using the scalar assign_add overload
// for T. Returns a reference to a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

// Converts every lane of a packed value to float. When T is already float the
// input is returned unchanged.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}

// Converts a packed float value to the packed output type O, lane by lane.
// When O's element type is float the input is returned unchanged.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}

133
// Stores `flag` to `*flag_addr` as a release operation.
// On SM70+ this is a single st.release.sys; on older targets it falls back to
// membar.sys followed by a volatile store.
// The "memory" clobber tells the compiler that the asm has side effects on
// memory, so it does not move ordinary loads/stores across the flag store.
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr)
               : "memory");
#else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr)
               : "memory");
#endif
}

// Loads and returns `*flag_addr` as an acquire operation.
// On SM70+ this is a single ld.acquire.sys; on older targets it falls back to
// a volatile load followed by membar.gl.
// The "memory" clobber tells the compiler that the asm has side effects on
// memory, so it does not move ordinary loads/stores across the flag load.
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
#else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
#endif
  return flag;
}

// Stores `flag` to `*flag_addr` with a volatile global store. No fence is
// issued; multi_gpu_barrier uses this variant when need_fence is false.
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

// Loads and returns `*flag_addr` with a volatile global load. No fence is
// issued; multi_gpu_barrier uses this variant when need_fence is false.
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

169
170
171
172
173
174
175
176
177
178
// Barrier across `ngpus` ranks for the calling block. Each of the first
// `ngpus` threads of the block handles one peer: it bumps its own counter,
// writes the new value into that peer's peer_counter slot for this rank, and
// spins until the matching slot in its own Signal holds the same value.
//
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
//
// sg: Signal pointers of all ranks, indexed by rank.
// self_sg: this rank's own Signal.
// rank: this rank's index.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
                               int rank) {
  static_assert(!(is_start && need_fence),
                "Start barrier shouldn't need fence.");
  // For a non-start barrier, wait for the whole block before signalling peers.
  if constexpr (!is_start) __syncthreads();
  if (threadIdx.x < ngpus) {
    // Increment the counter. Technically we only need one counter, but we use
    // multiple per block to eliminate the need to share the counter via smem.
    auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
    // Write the expected counter value to peer and wait for correct value from
    // peer. val % 2 selects one of the two alternating peer_counter sets.
    auto peer_counter_ptr =
        &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
    auto self_counter_ptr =
        &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
    if constexpr (need_fence) {
      st_flag_release(peer_counter_ptr, val);
      while (ld_flag_acquire(self_counter_ptr) != val);
    } else {
      st_flag_volatile(peer_counter_ptr, val);
      while (ld_flag_volatile(self_counter_ptr) != val);
    }
  }
  // Hold the rest of the block until the signalling threads have seen every
  // peer arrive.
  if constexpr (is_start || need_fence) __syncthreads();
}

// Sums element `idx` of the `ngpus` buffers in `ptrs` and returns the packed
// result. Each operand is converted to the accumulator type A, accumulated in
// the order ptrs[0], ptrs[1], ..., and the sum is converted back to P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

// One-stage all-reduce: after a start barrier, every rank reads element idx
// from all ranks' buffers, sums them, and writes the sum to `result`; a
// closing barrier follows the loop.
//
// _dp: table of the ranks' input pointers (copied into a local first).
// sg / self_sg: Signals of all ranks / of this rank, used by the barriers.
// result: output buffer, written as packed P elements.
// rank: this rank's index.
// size: element count in units of the packed type P (allreduce() divides the
//       scalar length by the pack width before launching).
// Launch bounds: at most 512 threads per block.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}

// Returns the address immediately after the Signal object that `sg` points
// to, reinterpreted as P*. The two-stage kernel uses it as the rank's
// temporary buffer.
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(sg + 1);
}

// Two-stage all-reduce (reduce-scatter followed by all-gather).
//
// The `size` packed elements are split into `ngpus` parts of size / ngpus
// elements; the last rank additionally takes the remainder. In stage 1 each
// rank sums its own part across all ranks' inputs into its temporary buffer
// (the memory right after its Signal). In stage 2 each rank copies every
// rank's reduced part from that rank's temporary buffer into `result`.
//
// _dp: table of the ranks' input pointers.
// sg / self_sg: Signals of all ranks / of this rank; also locate the
//               temporary buffers via get_tmp_buf.
// result: output buffer, written as packed P elements.
// rank: this rank's index.
// size: element count in units of the packed type P.
// Launch bounds: at most 512 threads per block.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  // Size of the last rank's part, which absorbs the remainder.
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from all
  // ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's part has elements at idx >= part.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
278
279
280
281
// Byte-array view of a cudaIpcMemHandle_t, used as the key type of the
// ipc_handles_ map in CustomAllreduce. The static_asserts check that the key
// has the same size and alignment as the handle it is cast from.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

282
283
284
285
286
287
288
289
// Host-side driver for the custom all-reduce kernels. It opens the peers' IPC
// memory handles, keeps the table of registered buffers, and launches the
// one-stage or two-stage reduction kernel on a caller-provided stream.
//
// NOTE(review): the destructor closes every opened IPC handle, so copying an
// instance would close the same handles twice — confirm callers never copy.
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  // maps a registered local buffer address to its RankData entry
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  // inputs seen during stream capture that are not registered yet
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * handles[i] / offsets[i] locate rank i's Signal inside its IPC allocation;
   * the entry for `rank` itself is not opened. The world size is taken from
   * offsets.size().
   *
   * Throws std::runtime_error if the world size exceeds the 8 slots of
   * RankSignals / RankData, or if rank is outside [0, world size).
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    // sg_.signals and RankData::ptrs have exactly 8 slots; reject anything
    // that would index past them before opening any handle.
    if (world_size_ > 8)
      throw std::runtime_error(
          "custom allreduce supports at most 8 ranks. Got " +
          std::to_string(world_size_));
    if (rank_ < 0 || rank_ >= world_size_)
      throw std::runtime_error("invalid rank " + std::to_string(rank_) +
                               " for world size " +
                               std::to_string(world_size_));
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }

  // Returns the mapped pointer for an IPC handle (sizeof(cudaIpcMemHandle_t)
  // bytes at ipc_handle), opening it on first use and caching the result in
  // ipc_handles_.
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((const IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  // For every buffer recorded during stream capture, returns the IPC handle of
  // its enclosing allocation (all handles concatenated byte-wise) and the
  // buffer's byte offset inside that allocation.
  // Throws std::runtime_error if the allocation base cannot be queried.
  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  // Throws std::runtime_error if `num` more RankData entries would not fit
  // between d_rank_data_base_ and d_rank_data_end_.
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  // Registers one buffer: handles[i] / offsets[i] locate rank i's copy, `self`
  // is this rank's own address. The resulting pointer table is copied into the
  // next free RankData slot and indexed by `self` in buffers_.
  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  //
  // handles[j] holds rank j's concatenated IPC handles and offsets[j][i] the
  // byte offset of its i-th buffer, in the order the buffers were captured.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Launches the all-reduce of `size` elements of T from `input` into `output`
   * on `stream`.
   *
   * While the stream is being captured, `input` is recorded in
   * graph_unreg_buffers_ and the kernel is pointed at the RankData slot it
   * will occupy once register_graph_buffers() runs. Otherwise `input` must
   * already have been registered with register_buffer().
   *
   * Throws std::runtime_error if size is not a multiple of the pack width, if
   * block_limit exceeds kMaxBlocks, if `input` is not registered, if the world
   * size is not one of 2, 4, 6, 8, or if more than 2 ranks are used without
   * full NVLink (previously that case silently launched no kernel).
   *
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    // From here on, size counts packed elements rather than scalars.
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
    // TODO(hanzhi713): Threshold is different for A100 and H100.
    // Add per device threshold.
#define REDUCE_CASE(ngpus)                                                  \
  case ngpus: {                                                             \
    if (world_size_ == 2) {                                                 \
      KL(ngpus, cross_device_reduce_1stage);                                \
    } else if (full_nvlink_) {                                              \
      if ((world_size_ <= 4 && bytes < 512 * 1024) ||                       \
          (world_size_ <= 8 && bytes < 256 * 1024)) {                       \
        KL(ngpus, cross_device_reduce_1stage);                              \
      } else {                                                              \
        KL(ngpus, cross_device_reduce_2stage);                              \
      }                                                                     \
    } else {                                                                \
      throw std::runtime_error(                                             \
          "custom allreduce requires full NVLink for more than 2 gpus. "    \
          "Actual num gpus = " +                                            \
          std::to_string(world_size_));                                     \
    }                                                                       \
    break;                                                                  \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  // Closes every IPC mapping opened through open_ipc_handle().
  ~CustomAllreduce() {
    for (const auto& [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 half *, int, int, int);
*/
}  // namespace vllm