custom_all_reduce.cuh 16.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// Evaluates a CUDA runtime call exactly once. If the returned status is not
// cudaSuccess, prints the file, line and CUDA error string to stderr and
// terminates the process with EXIT_FAILURE.
#define CUDACHECK(cmd)                                                 \
  do {                                                                 \
    cudaError_t e = (cmd);                                             \
    if (e != cudaSuccess) {                                            \
      fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n", __FILE__,     \
              __LINE__, cudaGetErrorString(e));                        \
      exit(EXIT_FAILURE);                                              \
    }                                                                  \
  } while (0)

namespace vllm {

26
27
28
// Upper bound on the number of thread blocks per launch; sizes the per-block
// flag rows in Signal.
constexpr int kMaxBlocks = 64;

// Flag storage used by start_sync/end_sync to synchronize ranks.
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links
struct Signal {
  // start[b][r] is set to 1 by block b of rank r in start_sync and cleared by
  // the owning rank in end_sync.
  alignas(128) uint32_t start[kMaxBlocks][8];
  // end[b][r] is set to 1 by block b of rank r in end_sync and cleared by the
  // owning rank in start_sync.
  alignas(128) uint32_t end[kMaxBlocks][8];
};

34
// Pointer table for one registered buffer: ptrs[i] is the buffer address
// belonging to rank i (at most 8 ranks).
struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};
35

36
// Signal table: signals[i] points at the Signal block of rank i (at most 8
// ranks). Accessed through volatile so flag reads/writes are not elided.
struct __align__(16) RankSignals {
  volatile Signal* signals[8];
};
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

// Fixed-size array in the spirit of std::array, except that the whole object
// is aligned to alignof(T) * sz.
//   type - element type
//   size - number of elements
//   data - element storage
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  using type = T;
  static constexpr int size = sz;

  T data[sz];
};

// Maps an element type T to 16-byte packed types so that memory is accessed in
// the widest units available.
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
 private:
  // number of T elements that fit in 16 bytes
  static constexpr int kLanes = 16 / sizeof(T);

 public:
  // the (P)acked type for load/store
  using P = array_t<T, kLanes>;
  // the (A)ccumulator type for reduction: same lane count, float lanes
  using A = array_t<float, kLanes>;
};

#define DINLINE __device__ __forceinline__

// scalar cast functions

// Widens a half to float.
DINLINE float upcast_s(half val) {
  return __half2float(val);
}

// Narrows a float to T. Only declared here; each supported T provides an
// explicit specialization.
template <typename T>
DINLINE T downcast_s(float val);

// half specialization of downcast_s.
template <>
DINLINE half downcast_s<half>(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly

// In-place a += b for half values; returns a reference to a.
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}

// In-place a += b for float values; returns a reference to a.
DINLINE float& assign_add(float& a, float b) { return a += b; }
76
77
78
79
80
81
82

// bfloat16 overloads. Compiled in host passes (__CUDA_ARCH__ undefined) and in
// device passes where __CUDA_ARCH__ >= 800.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
// Widens a bfloat16 to float.
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }

// bfloat16 specialization of downcast_s.
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}

// In-place a += b for bfloat16 values; returns a reference to a.
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

// Element-wise in-place addition of two packed arrays: a.data[i] += b.data[i]
// for every lane, using the scalar assign_add overload for T. Returns a
// reference to a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

// Converts every lane of a packed array to float. When T is already float the
// input is returned unchanged; otherwise each lane goes through upcast_s.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) {
      widened.data[lane] = upcast_s(val.data[lane]);
    }
    return widened;
  } else {
    return val;
  }
}

// Converts a packed float array to the packed output type O. When O's element
// type is float the input is returned unchanged; otherwise each lane goes
// through downcast_s<O::type>.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) {
      narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    }
    return narrowed;
  } else {
    return val;
  }
}

126
127
128
129
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
// The first ngpus threads of each block each handle one peer rank; all
// threads of the block then meet at the trailing __syncthreads().
//   sg      - Signal pointers of all ranks
//   self_sg - this rank's own Signal
//   rank    - index of this rank
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
                        int rank) {
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->end[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->start[blockIdx.x][threadIdx.x]);
  }
  __syncthreads();
}

145
146
147
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
//
// With final_sync == true both the system-wide fence and the trailing
// __syncthreads() are skipped.
//   sg      - Signal pointers of all ranks
//   self_sg - this rank's own Signal
//   rank    - index of this rank
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
                      int rank) {
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
  // visible. Note that I did not manage to make this happen through a lot of
  // testing. Might be the case that hardware provides stronger guarantee than
  // the memory model.
  if constexpr (!final_sync) __threadfence_system();
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->start[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->end[blockIdx.x][threadIdx.x]);
  }
  if constexpr (!final_sync) __syncthreads();
}

// Sums the packed element at position idx across ngpus input buffers.
// Each input is widened to the accumulator type A, added in the fixed order
// ptrs[0], ptrs[1], ..., and the total is narrowed back to P.
//   ptrs - one packed input pointer per rank
//   idx  - index in units of P
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

// One-stage all-reduce: every rank reads all ranks' inputs directly and writes
// the complete reduced result.
//   _dp     - device pointer to the table of per-rank input pointers
//   sg      - Signal pointers of all ranks
//   self_sg - this rank's own Signal
//   result  - output buffer, written as packed P elements
//   rank    - index of this rank
//   size    - number of packed P elements (not scalar T elements)
// Launch bounds allow at most 512 threads per block; the first ngpus threads
// of every block take part in the cross-rank flag exchange.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
                               volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}

// Returns the temporary buffer associated with a Signal: the address
// immediately after the Signal object, reinterpreted as P*. The volatile
// qualifier is dropped because the buffer holds data rather than flags.
template <typename P>
DINLINE P* get_tmp_buf(volatile Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}

// Two-stage all-reduce (reduce-scatter followed by all-gather).
// Stage 1: each rank reduces its own slice [start, end) of the input across
// all ranks into the temporary buffer located right after its Signal.
// Stage 2: each rank copies every rank's reduced slice from that rank's
// temporary buffer into result.
//   _dp     - device pointer to the table of per-rank input pointers
//   sg      - Signal pointers of all ranks
//   self_sg - this rank's own Signal
//   result  - output buffer, written as packed P elements
//   rank    - index of this rank
//   size    - number of packed P elements (not scalar T elements)
// The last rank's slice also absorbs the size % ngpus remainder.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
                               volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
  // Rotate the rank order so that index 0 always refers to this rank.
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from all
  // ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's slice extends past `part` elements.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

Hanzhi Zhou's avatar
Hanzhi Zhou committed
249
250
251
252
// Byte array with the same size and alignment as cudaIpcMemHandle_t, used as
// an ordered-map key for IPC handles.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t),
              "IPC_KEY and cudaIpcMemHandle_t must have the same size");
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t),
              "IPC_KEY and cudaIpcMemHandle_t must have the same alignment");

253
254
255
256
257
258
259
260
/**
 * Host-side driver for the custom all-reduce kernels.
 *
 * Keeps the per-rank Signal pointers, the device table of registered buffer
 * pointers and the IPC mappings opened to reach other ranks' memory.
 *
 * NOTE(review): the destructor closes every opened IPC mapping and the class
 * is implicitly copyable, so destroying a copy as well as the original closes
 * each mapping twice. Instances should not be copied.
 */
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  // registered input buffer address -> its RankData entry on the device
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  // [d_rank_data_base_, d_rank_data_end_) is the still-unused part of the
  // table; the base advances with every registration.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  // inputs seen while a stream capture was active, not yet registered
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   *
   * @param meta          this rank's Signal (followed by its temp buffer)
   * @param rank_data     device memory used as the RankData table
   * @param rank_data_sz  size of rank_data in bytes
   * @param handles       one IPC handle per rank for the meta allocations
   * @param offsets       per-rank byte offset of the Signal inside the opened
   *                      allocation; its length defines the world size
   * @param rank          index of this rank
   * @param full_nvlink   selects the multi-GPU (> 2) kernel paths
   * @throws std::runtime_error if more ranks are given than RankSignals can
   *         hold
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    // sg_.signals is a fixed-size array; refuse to write past its end.
    constexpr size_t max_ranks = sizeof(sg_.signals) / sizeof(sg_.signals[0]);
    if (offsets.size() > max_ranks)
      throw std::runtime_error("custom allreduce supports at most " +
                               std::to_string(max_ranks) + " ranks. Got " +
                               std::to_string(offsets.size()));
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }

  /**
   * Returns the local mapping for an IPC handle, opening it on first use.
   * Mappings are cached by handle bytes, so a handle is opened at most once.
   *
   * @param ipc_handle points at sizeof(cudaIpcMemHandle_t) bytes of handle
   */
  char* open_ipc_handle(const void* ipc_handle) {
    // Copy the bytes instead of casting so that constness and the source
    // buffer's type are irrelevant.
    IPC_KEY key;
    std::memcpy(key.data(), ipc_handle, sizeof(IPC_KEY));
    auto [it, new_handle] = ipc_handles_.insert({key, nullptr});
    if (new_handle) {
      cudaIpcMemHandle_t mem_handle;
      std::memcpy(&mem_handle, ipc_handle, sizeof(mem_handle));
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, mem_handle,
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  /**
   * Builds the IPC metadata for all buffers recorded during stream capture.
   *
   * @return {handles, offsets}: handles holds one cudaIpcMemHandle_t per
   *         buffer, concatenated as raw bytes; offsets[i] is the byte offset
   *         of buffer i from the start of its allocation.
   * @throws std::runtime_error if the allocation base cannot be queried
   */
  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(std::move(handles), std::move(offsets));
  }

  /**
   * Verifies that `num` more RankData entries fit in the device table.
   *
   * @throws std::runtime_error naming the number of entries that do not fit
   */
  void check_rank_data_capacity(size_t num = 1) {
    // Compare counts rather than forming a pointer beyond the table's end.
    auto remaining = static_cast<size_t>(d_rank_data_end_ - d_rank_data_base_);
    if (num > remaining)
      throw std::runtime_error("Rank data buffer is overflowed by " +
                               std::to_string(num - remaining));
  }

  /**
   * Registers one buffer for use outside of stream capture.
   *
   * @param handles  per-rank IPC handle bytes (entry rank_ is not read)
   * @param offsets  per-rank byte offset of the buffer in its allocation
   * @param self     this rank's own buffer; also the key that allreduce()
   *                 later looks up
   * @throws std::runtime_error if the table is full or the arguments do not
   *         cover every rank
   */
  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    const auto num_ranks = static_cast<size_t>(world_size_);
    if (handles.size() < num_ranks || offsets.size() < num_ranks)
      throw std::runtime_error(
          "register_buffer requires one handle and one offset per rank");
    // Value-initialize so that slots of unused ranks are null on the device.
    RankData data{};
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        if (handles[i].size() < sizeof(cudaIpcMemHandle_t))
          throw std::runtime_error("IPC handle of rank " + std::to_string(i) +
                                   " is too short");
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  /**
   * Registers every buffer recorded during stream capture and clears the
   * pending list.
   *
   * @param handles  per rank, the concatenated handle bytes as produced by
   *                 get_graph_buffer_ipc_meta() (entry rank_ is not read)
   * @param offsets  per rank, one byte offset per pending buffer
   * @throws std::runtime_error if the table is full or the arguments do not
   *         cover every rank and pending buffer
   */
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    const auto num_ranks = static_cast<size_t>(world_size_);
    if (handles.size() < num_ranks || offsets.size() < num_ranks)
      throw std::runtime_error(
          "register_graph_buffers requires handles and offsets for every "
          "rank");
    for (int j = 0; j < world_size_; j++) {
      if (j == rank_) continue;
      if (handles[j].size() < num_buffers * sizeof(cudaIpcMemHandle_t) ||
          offsets[j].size() < num_buffers)
        throw std::runtime_error("graph buffer metadata of rank " +
                                 std::to_string(j) + " is too short");
    }
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Launches the all-reduce of `input` into `output` on `stream`.
   *
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus.
   *
   * While `stream` is being captured, the input is recorded as a pending
   * graph buffer and the kernel is given the table slot it will occupy once
   * register_graph_buffers() runs. Otherwise the input must already have been
   * registered with register_buffer().
   *
   * @param size         number of T elements; must be a multiple of the
   *                     packed width (16 / sizeof(T))
   * @param threads      threads per block; must be between world_size_ (one
   *                     thread per rank drives the flag exchange) and 512
   *                     (the kernels' launch bound)
   * @param block_limit  upper bound on blocks; at most kMaxBlocks
   * @throws std::runtime_error on invalid arguments, an unregistered input,
   *         an unsupported world size, or more than 2 GPUs without full
   *         NVLink
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));
    // threads is a divisor below and is bounded by __launch_bounds__(512, 1);
    // the sync helpers need one thread per rank in every block.
    if (threads < world_size_ || threads < 1 || threads > 512)
      throw std::runtime_error(
          "threads per block must be between the number of gpus and 512. "
          "Got " +
          std::to_string(threads));
    // Nothing to reduce; a launch with zero blocks would be invalid.
    if (size == 0) return;

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    // From here on, size counts packed elements.
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                                                  \
  case ngpus: {                                                             \
    if (world_size_ == 2) {                                                 \
      KL(ngpus, cross_device_reduce_1stage);                                \
    } else if (full_nvlink_) {                                              \
      if ((world_size_ <= 4 && bytes < 512 * 1024) ||                       \
          (world_size_ <= 8 && bytes < 256 * 1024)) {                       \
        KL(ngpus, cross_device_reduce_1stage);                              \
      } else {                                                              \
        KL(ngpus, cross_device_reduce_2stage);                              \
      }                                                                     \
    } else {                                                                \
      throw std::runtime_error(                                             \
          "custom allreduce with more than 2 gpus requires full nvlink");   \
    }                                                                       \
    break;                                                                  \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  /** Closes every IPC mapping opened through open_ipc_handle(). */
  ~CustomAllreduce() {
    for (const auto& entry : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(entry.second));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and
 * add a template instantiation:
 *   template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 *                                                        half *, int, int, int);
 */
}  // namespace vllm