custom_all_reduce_hip.cuh 19.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// !!! This is a file automatically generated by hipify!!!
#pragma once

#include <hip/hip_runtime.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#else
#include <hip/hip_bf16.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

#include <array>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// Evaluates the HIP runtime call `cmd` exactly once. If it does not return
// hipSuccess, prints the failing file/line and the HIP error string to stderr
// and terminates the process with EXIT_FAILURE.
// The result is held in a uniquely named local so the macro cannot capture an
// identifier used inside `cmd` (with a plain `e`, CUDACHECK(f(e)) would read
// the macro's own, still uninitialized, variable).
#define CUDACHECK(cmd)                                  \
  do {                                                  \
    hipError_t cudacheck_err_ = (cmd);                  \
    if (cudacheck_err_ != hipSuccess) {                 \
      std::fprintf(                                     \
          stderr,                                       \
          "Failed: Cuda error %s:%d '%s'\n",            \
          __FILE__,                                     \
          __LINE__,                                     \
          hipGetErrorString(cudacheck_err_));           \
      std::exit(EXIT_FAILURE);                          \
    }                                                   \
  } while (0)

29
namespace sglang {
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155

// Maximum number of thread blocks an allreduce kernel may be launched with:
// Signal keeps one slot row per block, and CustomAllreduce::allreduce rejects
// a larger block_limit.
constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are not
// supported on PCIe links (the ROCm path uses only atomic loads/stores, no
// read-modify-write operations).
//
// Per-rank synchronization header. Peers write into this rank's copy through
// IPC-mapped pointers (see start_sync / end_sync). The allreduce scratch
// buffer is laid out immediately after this struct (see get_tmp_buf).
struct Signal {
  // start[b][r]: written by rank r when its block b enters start_sync.
  alignas(128) uint32_t start[kMaxBlocks][8];
  // end[b][r]: written by rank r when its block b enters end_sync.
  alignas(128) uint32_t end[kMaxBlocks][8];
  alignas(128) uint32_t _flag[kMaxBlocks];  // per-block barrier counter (ROCm path only)
};

// The addresses of one registered buffer on every rank, indexed by rank
// (at most 8 ranks). Entry `rank` is the local pointer; the others are peer
// allocations opened through IPC. CustomAllreduce copies these tables into
// the caller-provided device buffer, and the kernels read them from there.
#ifdef USE_ROCM
struct __align__(16) RankData {
  const void* ptrs[8];
};
#else
struct __align__(16) RankData {
  const void* __restrict__ ptrs[8];
};
#endif

// Every rank's Signal header, indexed by rank (at most 8 ranks). Entry `rank`
// is this process's own header; the others are IPC-mapped peer headers.
// Passed to the kernels by value. On the non-ROCm path the pointee is
// volatile, because start_sync/end_sync use plain volatile loads/stores there;
// the ROCm path uses explicit scoped atomics instead.
struct __align__(16) RankSignals {
#ifndef USE_ROCM
  volatile
#endif
      Signal* signals[8];
};

// Fixed-size array of `sz` values of T, aligned to alignof(T) * sz bytes
// (std::array does not provide this alignment) so a whole array can be moved
// with one wide load/store.
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  // Compile-time description of the contents, used by generic helpers.
  using type = T;
  static constexpr int size = sz;
  // The payload itself (the only non-static data member).
  T data[sz];
};

// Packed views of a scalar type T, sized to 16 bytes so that each access is
// intended to compile to a single 128-bit load/store (ld.128 / st.128):
//   A - float accumulators, one per packed lane, used while reducing;
//   P - 16 / sizeof(T) values of T, the unit of global-memory traffic.
template <typename T>
struct packed_t {
  using A = array_t<float, 16 / sizeof(T)>;
  using P = array_t<T, 16 / sizeof(T)>;
};

// Qualifier shorthand for the small device-only helpers below: they are
// compiled for the device and always force-inlined into their callers.
#define DINLINE __device__ __forceinline__

// Widen one half-precision value to float (scalar building block of upcast).
DINLINE float upcast_s(half val) {
  const float widened = __half2float(val);
  return widened;
}

// Narrow a float to scalar type T. Only explicit specializations exist; the
// primary template is declared but deliberately never defined.
template <typename T>
DINLINE T downcast_s(float val);

template <>
DINLINE half downcast_s<half>(float val) {
  const half narrowed = __float2half(val);
  return narrowed;
}

// In-place scalar accumulation: a = a + b, returning a reference to a.
// The half overload calls the __hadd intrinsic directly, because operator+
// for half/bfloat16 is unavailable when building with PyTorch.
DINLINE half& assign_add(half& a, half b) {
  const half sum = __hadd(a, b);
  a = sum;
  return a;
}
DINLINE float& assign_add(float& a, float b) {
  a = a + b;
  return a;
}

#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
// bfloat16 counterparts of upcast_s / downcast_s / assign_add. They are
// compiled for host passes and for device architectures >= 800.
DINLINE float upcast_s(nv_bfloat16 val) {
  const float widened = __bfloat162float(val);
  return widened;
}
template <>
DINLINE nv_bfloat16 downcast_s<nv_bfloat16>(float val) {
  const nv_bfloat16 narrowed = __float2bfloat16(val);
  return narrowed;
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  const nv_bfloat16 sum = __hadd(a, b);
  a = sum;
  return a;
}
#endif

// Lane-wise in-place accumulation of packed array b into a; returns a.
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int lane = 0; lane < N; ++lane) assign_add(a.data[lane], b.data[lane]);
  return a;
}

// Convert a packed array of T into a packed array of float accumulators.
// Float input is returned unchanged.
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (!std::is_same<T, float>::value) {
    array_t<float, N> widened;
#pragma unroll
    for (int lane = 0; lane < N; ++lane) widened.data[lane] = upcast_s(val.data[lane]);
    return widened;
  } else {
    return val;
  }
}

// Convert packed float accumulators back into the packed output type O
// (an array_t). When O already holds floats the input is returned as is.
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  using Elem = typename O::type;
  if constexpr (!std::is_same<Elem, float>::value) {
    O narrowed;
#pragma unroll
    for (int lane = 0; lane < O::size; ++lane) narrowed.data[lane] = downcast_s<Elem>(val.data[lane]);
    return narrowed;
  } else {
    return val;
  }
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
//
// Each block runs its own barrier, keyed by blockIdx.x. Thread t (for
// threadIdx.x < ngpus) handles peer rank t, so blockDim.x must be >= ngpus or
// some ranks are never signalled.
//  - ROCm: the next barrier value is self_sg->_flag[blockIdx.x] + 1. Thread t
//    stores it (relaxed, system scope) into rank t's start[blockIdx.x][rank]
//    slot, then spins until this rank's start[blockIdx.x][t] slot reaches it.
//    After the block barrier, thread 0 records the new value in _flag.
//  - Otherwise: thread t clears this rank's end[blockIdx.x][t] slot (so the
//    following end_sync can wait for it to become non-zero again), sets rank
//    t's start[blockIdx.x][rank] slot to 1, then spins until this rank's
//    start[blockIdx.x][t] slot is non-zero.
// The function contains __syncthreads(), so every thread of the block must
// call it.
template <int ngpus>
DINLINE void start_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
           flag)
      ;
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->end[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->start[blockIdx.x][threadIdx.x])
      ;
  }
  __syncthreads();
#endif
}

// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
//
// Each block runs its own barrier, keyed by blockIdx.x. Thread t (for
// threadIdx.x < ngpus) handles peer rank t. The block first runs
// __syncthreads() so that all of its threads have finished their prior work.
//  - ROCm: thread t stores _flag[blockIdx.x] + 1 into rank t's
//    end[blockIdx.x][rank] slot (system scope; release unless final_sync) and
//    spins until this rank's end[blockIdx.x][t] slot reaches that value
//    (acquire unless final_sync). A closing __syncthreads() follows, then
//    thread 0 records the new value in _flag.
//  - Otherwise: unless final_sync, __threadfence_system() orders prior writes
//    before the signal. Thread t clears this rank's start[blockIdx.x][t] slot
//    for the next start_sync, sets rank t's end[blockIdx.x][rank] slot to 1
//    and spins until this rank's end[blockIdx.x][t] slot is non-zero. The
//    closing __syncthreads() is skipped for the final barrier.
// The function contains __syncthreads(), so every thread of the block must
// call it.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  __syncthreads();
  // For a non-final barrier, the release store below (paired with the acquire
  // load on the receiving rank) is intended to make this block's prior writes
  // visible before the flag is observed. A final barrier only needs relaxed
  // ordering.
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
        flag,
        final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
        __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(
               &self_sg->end[blockIdx.x][threadIdx.x],
               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
               __MEMORY_SCOPE_DEVICE) < flag)
      ;
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
  // visible. Note that I did not managed to make this happen through a lot of
  // testing. Might be the case that hardware provides stronger guarantee than
  // the memory model.
  if constexpr (!final_sync) __threadfence_system();
  if (threadIdx.x < ngpus) {
    // reset flag for next time
    self_sg->start[blockIdx.x][threadIdx.x] = 0;
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
    // wait until we got true from all ranks
    while (!self_sg->end[blockIdx.x][threadIdx.x])
      ;
  }
  if constexpr (!final_sync) __syncthreads();
#endif
}

// Sum packed element `idx` across the ngpus buffers in `ptrs`. The sum is
// accumulated in float (A) strictly in the order ptrs[0], ptrs[1], ..., so
// callers control the accumulation order. The result is converted back to P.
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A acc = upcast(ptrs[0][idx]);
#pragma unroll
  for (int src = 1; src < ngpus; ++src) {
    const A term = upcast(ptrs[src][idx]);
    packed_assign_add(acc, term);
  }
  return downcast<P>(acc);
}

// One-shot allreduce. After a start barrier, every rank reads all ngpus input
// buffers (through the RankData table *_dp) and writes the full reduced
// result to its own `result`; a final barrier then closes the kernel.
// CustomAllreduce::allreduce picks it for 2 ranks and for small messages.
//
// Preconditions:
//  - `size` counts packed_t<T>::P elements, not scalars of T;
//  - gridDim.x <= kMaxBlocks, because Signal has one slot row per block
//    (allreduce enforces this through block_limit);
//  - blockDim.x <= 512 (__launch_bounds__) and blockDim.x >= ngpus (barrier).
// Each thread strides over the packed elements by gridDim.x * blockDim.x.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}

// Scratch buffer that lives directly behind a rank's Signal header, viewed as
// an array of P. The volatile qualifier (non-ROCm) is cast away on purpose.
template <typename P>
#ifdef USE_ROCM
DINLINE P* get_tmp_buf(Signal* sg) {
#else
DINLINE P* get_tmp_buf(volatile Signal* sg) {
#endif
  Signal* header = (Signal*)sg;
  Signal* past_header = header + 1;
  return reinterpret_cast<P*>(past_header);
}

// Two-stage allreduce (reduce-scatter, then all-gather), used for larger
// messages. The packed input is split into ngpus contiguous partitions of
// size / ngpus elements, and the last rank also takes the remainder.
//  - Stage 1: this rank reduces its own partition across all ranks and stores
//    it in its scratch buffer (the memory right after its Signal header).
//  - Barrier (non-final end_sync), so the peers' partitions become readable.
//  - Stage 2: every rank copies each partition from the owning rank's scratch
//    buffer into `result`.
// Preconditions are the same as cross_device_reduce_1stage: `size` counts
// packed elements, gridDim.x <= kMaxBlocks, and ngpus <= blockDim.x <= 512.
// The scratch buffer behind each Signal must hold at least
// size / ngpus + size % ngpus packed elements.
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  // Index i refers to rank (rank + i) % ngpus, so slot 0 is always this rank.
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, self_sg, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from all
  // ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      // Only the last rank's partition extends past `part` elements.
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}

// Raw bytes of a hipIpcMemHandle_t, used as an ordered-map key so that each
// distinct IPC handle is opened only once (see CustomAllreduce::ipc_handles_).
// The asserts guarantee that the two types have identical size and alignment,
// so handle bytes can be reinterpreted as an IPC_KEY.
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));

// Host-side driver for the custom allreduce kernels. It keeps the IPC-mapped
// Signal headers of all ranks, registers input buffers (eagerly or after graph
// capture) and launches the one- or two-stage kernel.
class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  // registered input buffers: local device pointer -> device-side RankData
  // entry holding every rank's address for that buffer
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks; [base, end) is the
  // still-unused part of the caller-provided rank_data buffer
  RankData *d_rank_data_base_, *d_rank_data_end_;
  // inputs used during stream capture that await register_graph_buffers()
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Signal) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   *
   * @param meta         this rank's Signal header, followed by scratch space
   * @param rank_data    device buffer that receives RankData tables
   * @param rank_data_sz size of rank_data in bytes
   * @param handles      per-rank IPC handles of the allocations holding each
   *                     rank's Signal; the entry for `rank` is not opened
   * @param offsets      byte offset of rank i's Signal inside the allocation
   *                     behind handles[i]; offsets.size() is the world size
   * @param rank         this process's rank, in [0, offsets.size())
   * @param full_nvlink  allreduce() only launches for more than 2 ranks when
   *                     this is true
   * @throws std::runtime_error if the world size exceeds 8 (the capacity of
   *         RankSignals / RankData) or rank is out of range
   */
  CustomAllreduce(
      Signal* meta,
      void* rank_data,
      size_t rank_data_sz,
      const hipIpcMemHandle_t* handles,
      const std::vector<int64_t>& offsets,
      int rank,
      bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    // RankSignals::signals and RankData::ptrs have room for 8 ranks; validate
    // before sg_.signals is indexed below.
    constexpr size_t kMaxRanks = 8;
    if (offsets.size() > kMaxRanks)
      throw std::runtime_error(
          "custom allreduce supports at most " + std::to_string(kMaxRanks) + " ranks. Got " +
          std::to_string(offsets.size()));
    if (rank_ < 0 || rank_ >= world_size_)
      throw std::runtime_error(
          "rank " + std::to_string(rank_) + " is out of range for world size " + std::to_string(world_size_));
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
      sg_.signals[i] = rank_sg;
    }
  }

  // Opens the IPC memory handle whose raw bytes start at `ipc_handle` and
  // returns the mapped base address. Mappings are cached in ipc_handles_
  // (keyed by the handle bytes), so each distinct handle is opened once.
  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] = ipc_handles_.insert({*static_cast<const IPC_KEY*>(ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(hipIpcOpenMemHandle(
          (void**)&ipc_ptr, *static_cast<const hipIpcMemHandle_t*>(ipc_handle), hipIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  // For every buffer recorded during graph capture, returns this rank's IPC
  // handle of the buffer's allocation (concatenated, sizeof(hipIpcMemHandle_t)
  // bytes each) and the buffer's byte offset from the allocation start.
  // Throws std::runtime_error if the allocation base cannot be queried.
  std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    const size_t num_buffers = graph_unreg_buffers_.size();
    const size_t handle_sz = sizeof(hipIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (hipPointerGetAttribute(
              &base_ptr,
#ifdef USE_ROCM
              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
              CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
              (hipDeviceptr_t)ptr) != hipSuccess)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(std::move(handles), std::move(offsets));
  }

  // Throws std::runtime_error unless `num` more RankData entries fit in the
  // remaining rank_data buffer.
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  // Registers `self` (this rank's input buffer) for use outside graph capture.
  // handles[i] holds rank i's raw IPC handle bytes and offsets[i] the buffer's
  // byte offset inside that allocation; the entries for this rank are not used.
  // Consumes one RankData slot.
  void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
    if (handles.size() < static_cast<size_t>(world_size_) || offsets.size() < static_cast<size_t>(world_size_))
      throw std::runtime_error(
          "register_buffer expects one handle and one offset per rank (world size " + std::to_string(world_size_) +
          ")");
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  //
  // handles[j] concatenates rank j's IPC handles for all captured buffers (as
  // produced by get_graph_buffer_ipc_meta), and offsets[j][i] is the matching
  // byte offset. Fills the RankData slots reserved during capture and clears
  // the list of captured buffers.
  void
  register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    if (handles.size() < static_cast<size_t>(world_size_) || offsets.size() < static_cast<size_t>(world_size_))
      throw std::runtime_error(
          "register_graph_buffers expects one handle blob and one offset list per rank (world size " +
          std::to_string(world_size_) + ")");
    const size_t num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (size_t i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * Allreduces `size` elements of `input` across all ranks into `output`,
   * asynchronously on `stream`.
   *
   * Outside stream capture, `input` must have been registered with
   * register_buffer(). During capture, `input` is recorded instead, and the
   * launched kernel reads the RankData slot that register_graph_buffers()
   * fills in later. `size` must be a multiple of packed_t<T>::P::size, and
   * block_limit must not exceed kMaxBlocks. World sizes 2, 4, 6 and 8 are
   * supported; more than 2 ranks also require full_nvlink. A zero `size`
   * launches nothing.
   *
   * Kernel choice: the one-stage kernel for 2 ranks or for small messages
   * (< 512 KiB with <= 4 ranks, < 256 KiB with <= 8 ranks); otherwise the
   * two-stage kernel.
   *
   * This is the result after careful grid search. Using 36 blocks give the best
   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
   * Not quite sure the underlying reason, but my guess is that too many SMs
   * will cause contention on NVLink bus. The ROCm build defaults block_limit
   * to 16 instead.
   *
   * @throws std::runtime_error when any precondition above is violated.
   */
  template <typename T>
  void allreduce(
      hipStream_t stream,
      T* input,
      T* output,
      int size,
#ifndef USE_ROCM
      int threads = 512,
      int block_limit = 36) {
#else
      int threads = 512,
      int block_limit = 16) {
#endif
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
    // Without this check, the dispatch below would silently launch nothing and
    // leave `output` untouched. Validate before any capture bookkeeping.
    if (world_size_ > 2 && !full_nvlink_)
      throw std::runtime_error(
          "custom allreduce requires full_nvlink for world size > 2. Got world size " + std::to_string(world_size_));

    RankData* ptrs;
    hipStreamCaptureStatus status;
    CUDACHECK(hipStreamIsCapturing(stream, &status));
    if (status == hipStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    // Nothing to reduce; a zero-block launch would be an invalid configuration.
    if (size == 0) return;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  hipLaunchKernelGGL(   \
      (name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus)                                                                        \
  case ngpus: {                                                                                   \
    if (world_size_ == 2) {                                                                       \
      KL(ngpus, cross_device_reduce_1stage);                                                      \
    } else {                                                                                      \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);                                                    \
      } else {                                                                                    \
        KL(ngpus, cross_device_reduce_2stage);                                                    \
      }                                                                                           \
    }                                                                                             \
    break;                                                                                        \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
    // Kernel launches do not return errors directly; surface bad launch
    // configurations here instead of at some unrelated later call.
    CUDACHECK(hipGetLastError());
  }

  // Unmaps every peer allocation opened through open_ipc_handle().
  ~CustomAllreduce() {
    for (const auto& [_, ptr] : ipc_handles_) {
      CUDACHECK(hipIpcCloseMemHandle(ptr));
    }
  }
};  // class CustomAllreduce
576
577
578
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 * a template instantiation:
 * template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
 *                                                        half *, int, int, int);
 */
582
}  // namespace sglang