/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>

#include "trt_reduce_internal.cuh"
#include "utils.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}
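
// Note on the two flavors of flag accessors above:
// - st_flag_release / ld_flag_acquire use system-scope release/acquire semantics, so data written before the
//   release store is visible to a peer once its acquire load observes the flag (used by block_barrier when
//   need_fence == true).
// - st_flag_volatile / ld_flag_volatile only force the flag itself through the memory system and impose no
//   ordering on surrounding accesses; they serve the cheaper barrier path where only arrival is signalled.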

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////

// Type converter that packs the data format into a 128-bit data type
//
using PackedFloat = union {
  int4 packed;
  float unpacked[4];
};

using PackedHalf = union {
  int4 packed;
  half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes {};

template <>
struct PackedOn16Bytes<float> {
  using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half> {
  using Type = PackedHalf;
};

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
  int4 packed;
  __nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16> {
  using Type = PackedBFloat16;
};
#endif

// Add two 128-bit packed values
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
  T c;
  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
  return c.packed;
}

__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]
    // Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension

    // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
    size_t offset = (flag % 2) ? world_size : 0;

    if (bidx == 0) {
      st_flag_release(flag, signals[tidx] + offset + local_rank);
    }

    // All blocks check that the corresponding block 0 on the other GPUs has set the flag
    // No deadlock because block #0 is always the first block started
    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
    while (ld_flag_acquire(peer_barrier_d) != flag) {
    }
  }

  __syncthreads();
}
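
// Both multi_gpu_barrier above and block_barrier below double-buffer their flag slots: the (flag % 2) term selects
// one of two disjoint sets of flag locations on alternating invocations, so a rank that has already raced ahead
// into the next barrier call writes to the other slot set and cannot overwrite a flag a slower peer is still
// polling on.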

template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx,
    int const grid_size) {
  if constexpr (!start) {
    __syncthreads();
  }
  // After this function, the block of id == bidx of each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
    // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
    // Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension

    // Block broadcasts its flag (local_rank on emitting dimension) to all receivers
    uint32_t flag_block_offset = world_size + bidx * world_size;

    flag_block_offset += (grid_size + 1) * world_size * (flag % 2);

    uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
    // Blocks check that corresponding blocks on other GPUs have also set the flag
    if constexpr (need_fence) {
      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_acquire(peer_barrier_d) != flag) {
      }
    } else {
      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
      while (ld_flag_volatile(peer_barrier_d) != flag) {
      }
    }
  }

  __syncthreads();
}

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
  // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
  // The message is partitioned into chunks as detailed below:
  //               message
  //       |-------------------|
  // GPU 0 | B0 | B1 | B2 | B3 |
  // GPU 1 | B0 | B1 | B2 | B3 |
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 copies the chunk it is responsible for from local_input to the shareable buffer
  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
  // 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
  //
  // With COPY_INPUT == false, step 1 is skipped and gpu_barrier is used instead of block_barrier during step 2.
  // We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready.
  //
  // With PUSH_MODE, we consider that the shared buffer is of size:
  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 pushes the chunk it is responsible for into all other GPUs:
  //    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
  // 2. Block sync so the chunk has been shared with the other GPUs
  // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
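  //
  // Index arithmetic example with hypothetical numbers: for T = half, NUM_ELTS = 16 / sizeof(half) = 8. If
  // elts_per_block = 4096 and blockDim.x = 512, block B1 covers [4096, min(8192, elts_per_rank)) and thread t
  // starts at 4096 + 8 * t, advancing by blockDim.x * NUM_ELTS = 4096 elements per loop iteration.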

  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one 16-byte access for comms
  static constexpr int NUM_ELTS = 16 / sizeof(T);

  // Packed data type for comms
  using PackedStruct = typename PackedOn16Bytes<T>::Type;

  // The source pointers. Distributed round-robin for the different warps.
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  // Start and end offsets of the thread
  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);

  if constexpr (COPY_INPUT) {
    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
    // Copy from local buffer to shareable buffer
    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
    }
  }
  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
    // Iterate over the different ranks/devices on the node to load the values.
    PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
    }

    // Sum the values from the different ranks.
    PackedStruct sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the destination buffer.
    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
  }
}

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
  // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
  // The message is partitioned into chunks as detailed below:
  //               message
  //       |-------------------|
  //       |--GPU 0--|--GPU 1--| (GPU responsibility parts)
  // GPU 0 | B0 | B1 | B0 | B1 |
  // GPU 1 | B0 | B1 | B0 | B1 |
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 copies all chunks it is responsible for from local_input to the shareable buffer
  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
  // 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
  //    part (the first half of the message, see GPU responsibility row above)
  // 3bis. Likewise, B0 on GPU 1 gathers and sums the chunks from GPU 0 for the part where GPU 1 is
  //       responsible: the second half of the message.
  // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
  // 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
  //    For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
  //
  // With COPY_INPUT == false, step 1 is skipped and gpu_barrier is used instead of block_barrier during step 2.
  // We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready
  // to be read.
  //
  // Note that compared to one-shot, one block (CTA) copies multiple input chunks and writes multiple output chunks.
  // However, it's only responsible for the summation of a single chunk.
  //
  // With PUSH_MODE, we consider that the shared buffer is of size:
  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
  //
  // Here the step-by-step behavior of one block:
  // 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
  //    params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
  // 2. Block sync so the chunks have been shared with the other GPUs
  // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
  // 4. block barrier (corresponding blocks have finished reduction)
  // 5. Pull and write to the local buffer by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (the reduction
  //    result is written at index 0 of the 2nd dim)
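  //
  // Worked example with hypothetical numbers: world_size = 2 and elts_total = 8192 give elts_per_rank = 4096 and
  // rank_offset = local_rank * 4096. A block on GPU 1 therefore reduces offsets rank_offset + local_offset in the
  // second half of the message, and the final gather loop reads the first half back from GPU 0's shared buffer
  // (COPY_INPUT == true) or from tmp_result_buffers[0] (COPY_INPUT == false).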

  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;
  int const grid_size = gridDim.x;

  // The number of elements packed into one 16-byte access for comms
  static constexpr int PACKED_ELTS = 16 / sizeof(T);
  using PackedType = typename PackedOn16Bytes<T>::Type;

  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
  T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);

  size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
  size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);

  T* buffers[RANKS_PER_NODE];
  T* buffers_unorder[RANKS_PER_NODE];
  int ranks[RANKS_PER_NODE];
#pragma unroll
  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
    // A rank mapping that scatters reads across peers as much as possible
    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
    ranks[ii] = rank;
    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
  }

#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
#endif

  if constexpr (COPY_INPUT) {
    // Copy all blocks from local buffer to shareable buffer
    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
        if (offset_rank >= params.elts_total) {
          continue;
        }
        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
      }
    }
  }
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
    size_t const responsible_block_offset = local_offset + params.rank_offset;

    // Iterate over the different ranks/devices on the node to load the values.
    PackedType vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
    }

    // Sum the values from the different ranks.
    PackedType sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      sums.packed = add128b(sums, vals[rank]);
    }

    // Store to the local buffer or tmp buffer
    if constexpr (COPY_INPUT) {
      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
    } else {
      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
    }
  }

  block_barrier<false, true>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);

  // Gather all needed elts from other intra-node ranks
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      // use round-robin gathering from other ranks
      size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
      if (offset_rank >= params.elts_total) {
        continue;
      }
      if constexpr (COPY_INPUT) {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
      } else {
        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
      }
    }
  }
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int divUp(int a, int b) {
  return (a + b - 1) / b;
}

inline int roundUp(int a, int n) {
  return divUp(a, n) * n;
}

std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      assert(params.elts_total % elts_per_thread == 0);
      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
      params.elts_per_rank = params.elts_total;
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
      size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);

      /*
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      */
      while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
        blocks_per_grid += 1;
      }

      threads_per_block = total_threads / blocks_per_grid;

      // NOTE: if the grid is still too large, shrink blocks_per_grid by the smallest integer factor that keeps it
      // within MAX_ALL_REDUCE_BLOCKS (each block then covers proportionally more data)
      if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
        size_t iter_factor = 1;
        while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
          iter_factor += 1;
        }
        blocks_per_grid /= iter_factor;
      }
      params.elts_per_rank = params.elts_total / params.ranks_per_node;
      params.rank_offset = params.local_rank * params.elts_per_rank;
      params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
      break;
    }
    default:
      assert(false && "Algorithm not supported here.");
  }

  return std::make_tuple(blocks_per_grid, threads_per_block);
}
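
// Worked example of the ONESHOT configuration, assuming DEFAULT_BLOCK_SIZE = 512 (matching the
// __launch_bounds__(512, 1) on the kernels), WARP_SIZE = 32 and MAX_ALL_REDUCE_BLOCKS >= 2: for T = half
// (elts_per_thread = 8) and elts_total = 8192, total_threads = roundUp(1024, 32) = 1024, threads_per_block = 512,
// blocks_per_grid = divUp(1024, 512) = 2, elts_per_block = roundUp(divUp(8192, 2), 8) = 4096 and
// elts_per_rank = elts_total = 8192.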

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
    AllReduceStrategyType algo,
    AllReduceParams& param,
    int blocks_per_grid,
    int threads_per_block,
    cudaStream_t stream) {
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
    case AllReduceStrategyType::TWOSHOT: {
      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
      break;
    }
  }
}

template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
  size_t elts_per_thread = 16 / sizeof(T);
  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
  switch (param.ranks_per_node) {
    case 2:
      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
}

template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
  if (param.is_capturing) {
    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
  } else {
    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
  }
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}

void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }

  switch (data_type) {
    case at::ScalarType::Float:
      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
      break;
    case at::ScalarType::Half:
      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
      break;
#endif
    default:
      assert(false && "Unsupported data type");
  }
}
}  // namespace trt_llm