Unverified Commit bcbbf519 authored by Yi Zhang, committed by GitHub

sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (#5079)

parent 0d99adb7
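The new op surface mirrors vLLM's custom allreduce. A minimal per-rank usage sketch follows (hedged: `custom_ops` stands for the sgl_kernel Python wrappers updated later in this diff, `create_shared_buffer` is the IPC helper from the updated test, and `rank`/`group` come from an already-initialized torch.distributed setup; sizes are illustrative, not prescriptive):

# Sketch only: per-rank call sequence for the vLLM-style custom allreduce.
# Assumes torch.distributed is initialized and that create_shared_buffer (the
# test helper) returns the list of per-rank IPC-shared device pointers.
import torch

max_size = 8192 * 1024
meta_ptrs = create_shared_buffer(custom_ops.meta_size() + max_size, group=group)
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{rank}")
buffer_ptrs = create_shared_buffer(max_size, group=group)

fa = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
custom_ops.register_buffer(fa, buffer_ptrs)

inp = torch.randn(4096, dtype=torch.float16, device=f"cuda:{rank}")
out = torch.empty_like(inp)
# inp is staged into the registered buffer, reduced across ranks, written to out.
custom_ops.all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)

custom_ops.dispose(fa)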
@@ -157,8 +157,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
 set(SOURCES
-    "csrc/allreduce/trt_reduce_internal.cu"
-    "csrc/allreduce/trt_reduce_kernel.cu"
+    "csrc/allreduce/custom_all_reduce.cu"
     "csrc/attention/lightning_attention_decode_kernel.cu"
     "csrc/elementwise/activation.cu"
     "csrc/elementwise/fused_add_rms_norm_kernel.cu"
...
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "custom_all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
vllm::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new vllm::CustomAllreduce(
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows a transpose of a contiguous slice (i.e. slicing along the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(
stream, reinterpret_cast<float*>(reg_buffer), reinterpret_cast<float*>(out.data_ptr()), out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(
stream, reinterpret_cast<half*>(reg_buffer), reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream,
reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
out.numel());
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}
int64_t meta_size() {
return sizeof(vllm::Signal);
}
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets);
}
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#include "utils.h"
namespace vllm {
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for the two syncs. The reason is that
// it's possible for a peer GPU block to arrive at the second sync point while
// the current GPU block hasn't passed the first sync point. Thus, the peer GPU
// may write counter+1 while the current GPU is still busy waiting for counter.
// We use alternating counter arrays to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
return __half2float(val);
}
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason, when compiling with PyTorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) {
return a += b;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val)
;
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val)
;
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the addresses, so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
 * For each of the buffers, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(
Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// may get a different address. IPC handles have an internal reference-counting
// mechanism, so the overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CHECK_CUDA_SUCCESS(
cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* Performs allreduce, assuming input has already been registered.
*
 * The default block and grid configs are the result of a careful grid search.
 * Using 36 blocks gives the best or close to the best runtime on the devices I
 * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
 * use a small number of SMs. I'm not quite sure of the underlying reason, but my
 * guess is that too many SMs would cause contention on the NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
}
}
};
/**
 * To inspect PTX/SASS, copy-paste this header file into Compiler Explorer and add
 * a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 * half *, int, int, int);
*/
} // namespace vllm
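For reference, the CUDA-graph buffer registration process documented in CustomAllreduce above (steps 2-4) is driven from Python roughly as follows. This is a hedged sketch: `custom_ops` again stands for the sgl_kernel wrappers in this diff, and the all-gather of the handle bytes is the caller's responsibility per step 3 of the comment, not part of this header.

# Sketch only: exchange graph-buffer IPC metadata after CUDA graph capture.
# `fa` is the handle returned by init_custom_ar; `group` is an existing
# torch.distributed process group (both assumed to exist already).
import torch.distributed as dist

# Step 2: collect IPC handles and offsets for buffers recorded during capture.
handle_bytes, offsets = custom_ops.get_graph_buffer_ipc_meta(fa)

# Step 3: all-gather the metadata so every rank sees every peer's buffers.
world_size = dist.get_world_size(group=group)
all_handles = [None] * world_size
all_offsets = [None] * world_size
dist.all_gather_object(all_handles, handle_bytes, group=group)
dist.all_gather_object(all_offsets, offsets, group=group)

# Step 4: open peer IPC handles and append the resolved pointers to rank_data.
custom_ops.register_graph_buffers(fa, all_handles, all_offsets)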
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include "trt_reduce_internal.cuh"
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type converter that packs data into a 128-bit data type
//
using PackedFloat = union {
int4 packed;
float unpacked[4];
};
using PackedHalf = union {
int4 packed;
half2 unpacked[4];
};
template <typename T>
struct PackedOn16Bytes {};
template <>
struct PackedOn16Bytes<float> {
using Type = PackedFloat;
};
template <>
struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
};
template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
#endif
// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;
if (bidx == 0) {
st_flag_release(flag, signals[tidx] + offset + local_rank);
}
// All blocks check that corresponding block 0 on other GPUs have set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
}
__syncthreads();
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx,
int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
// After this function, the block of id == bidx of each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, 2, num_blocks, world_size]
// (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
// Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension
// Each block broadcasts its flag (local_rank on the emitting dimension) to all receivers
uint32_t flag_block_offset = world_size + bidx * world_size;
flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
// Blocks check that corresponding blocks on other GPUs have also set the flag
if constexpr (need_fence) {
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
} else {
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_volatile(peer_barrier_d) != flag) {
}
}
}
if constexpr (start || need_fence) {
__syncthreads();
}
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
//
// With COPY_INPUT == false, skip step 1 and use a GPU barrier instead of the block barrier during step 2.
// We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. block sync so the data written by this block is visible to the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedStruct = typename PackedOn16Bytes<T>::Type;
// The source pointers. Distributed round-robin for the different warps.
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
if constexpr (COPY_INPUT) {
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
// Copy from local buffer to shareable buffer
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
}
// Sum the values from the different ranks.
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
sums.packed = add128b(sums, vals[rank]);
}
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
block_barrier<false>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
// GPU 0 | B0 | B1 | B0 | B1 |
// GPU 1 | B0 | B1 | B0 | B1 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies all chunks it is responsible for, from local_input to the shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
// 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
// part (the first half of the message, see GPU responsibility row above)
// 3bis. Likewise, B0 on GPU 1 copies and sums the chunks for GPU 0,
// where GPU 1 is responsible: the second half of the message.
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
// 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
// For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
//
// With COPY_INPUT == false, skip step 1 and use a GPU barrier instead of the block barrier during step 2.
// We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready
// to be read.
//
// Note that compared to one-shot, one block (CTA) copies multiple input chunks and writes multiple output chunks.
// However, it's only responsible for the summation of a single chunk.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
// 2. block sync so the chunks have been shared with the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
// 4. block barrier (corresponding blocks have finished reduction)
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
// written at index 0 of 2nd dim)
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;
// The number of elements packed into one for comms
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedType = typename PackedOn16Bytes<T>::Type;
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
T* buffers[RANKS_PER_NODE];
T* buffers_unorder[RANKS_PER_NODE];
int ranks[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// A mapping of the ranks that scatters reads as much as possible
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
ranks[ii] = rank;
buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
#endif
if constexpr (COPY_INPUT) {
// Copy all blocks from local buffer to shareable buffer
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
}
}
}
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
size_t const responsible_block_offset = local_offset + params.rank_offset;
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
}
// Sum the values from the different ranks.
PackedType sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
sums.packed = add128b(sums, vals[rank]);
}
// Store to the local buffer or tmp buffer
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
} else {
*reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
}
}
block_barrier<false, true>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// use round-robin gathering from other ranks
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
} else {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
}
}
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int divUp(int a, int b) {
return (a + b - 1) / b;
}
inline int roundUp(int a, int n) {
return divUp(a, n) * n;
}
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
assert(params.elts_total % elts_per_thread == 0);
size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
params.elts_per_rank = params.elts_total;
break;
}
case AllReduceStrategyType::TWOSHOT: {
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_rank = params.elts_total / params.ranks_per_node;
params.rank_offset = params.local_rank * params.elts_per_rank;
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
break;
}
default:
assert(false && "Algorithm not supported here.");
}
return std::make_tuple(blocks_per_grid, threads_per_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
AllReduceStrategyType algo,
AllReduceParams& param,
int blocks_per_grid,
int threads_per_block,
cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
case AllReduceStrategyType::TWOSHOT: {
twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
}
}
template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
if (param.is_capturing) {
dispatchARKernelsCopyInput<T, false>(strat, param, stream);
} else {
dispatchARKernelsCopyInput<T, true>(strat, param, stream);
}
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
switch (data_type) {
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
#endif
default:
assert(false && "Unsupported data type");
}
}
} // namespace trt_llm
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include "trt_reduce_internal.cuh"
#include "utils.h"
using namespace trt_llm;
using fptr_t = int64_t;
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
public:
AllReduceMeta(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->barrier_in = barrier_in;
this->barrier_out = barrier_out;
this->tmp_result_buffers = tmp_result_buffers;
this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
RankData data;
for (int i = 0; i < world_size; i++) {
data.ptrs[i] = (void*)buffers[i];
}
auto d_data = this->rank_data_base++;
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
this->buffers = d_data;
}
~AllReduceMeta() {
for (auto [_, ptr] : ipc_handles_) {
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
}
}
public:
int world_size;
int rank_id;
std::vector<fptr_t> barrier_in;
std::vector<fptr_t> barrier_out;
std::vector<fptr_t> tmp_result_buffers;
int barrier_flag = 1;
RankData* buffers;
RankData* rank_data_base;
std::vector<void*> graph_unreg_buffers;
std::map<IPC_KEY, char*> ipc_handles_;
};
// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
switch (dtype) {
case at::ScalarType::Float:
return 32;
case at::ScalarType::Half:
case at::ScalarType::BFloat16:
return 16;
default:
assert(false && "Unsupported data type");
}
}
// Check if customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
// The customized all-reduce kernel has the following requirement(s).
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
delete fa;
}
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto num_buffers = m->graph_unreg_buffers.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = m->graph_unreg_buffers[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
assert(false && "failed to get pointer attr");
}
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
std::vector<int64_t> bytes(handles.begin(), handles.end());
return std::make_pair(bytes, offsets);
}
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// may get a different address. IPC handles have an internal reference-counting
// mechanism, so the overhead should be small.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
std::vector<std::string> handle_bytes;
handle_bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
}
auto num_buffers = m->graph_unreg_buffers.size();
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = m->graph_unreg_buffers[i];
auto& rd = rank_data[i];
for (int j = 0; j < m->world_size; j++) {
if (j != m->rank_id) {
char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CHECK_CUDA_SUCCESS(
cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
m->rank_data_base += num_buffers;
m->graph_unreg_buffers.clear();
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto stream = c10::cuda::getCurrentCUDAStream().stream();
auto num_elements = inp.numel();
auto dtype = inp.scalar_type();
AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
// should be guaranteed in Python code
assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
assert(CanApplyCustomAllReduce(num_elements, dtype));
// Initialize the all-reduce kernel arguments.
int world_size = m->world_size;
AllReduceParams params;
params.ranks_per_node = world_size;
params.rank = m->rank_id;
params.local_rank = m->rank_id;
params.local_input_buffer_ptr = inp.data_ptr();
params.local_output_buffer_ptr = out.data_ptr();
params.elts_total = inp.numel();
params.elts_size = inp.element_size();
params.barrier_flag = ++(m->barrier_flag);
cudaStreamCaptureStatus status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
params.is_capturing = (status == cudaStreamCaptureStatusActive);
if (params.is_capturing) {
params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
} else {
params.peer_comm_buffer_ptrs = m->buffers;
}
for (int i = 0; i < world_size; ++i) {
params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
}
auto data_type = out.scalar_type();
trtCustomAllReduce(params, data_type, strategy, stream);
}
@@ -26,15 +26,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
   m.def("register_graph_buffers", &register_graph_buffers);
   m.def("dispose", &dispose);
+  m.def("meta_size", &meta_size);
+  m.def("register_buffer", &register_buffer);
   m.def(
-      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
-      "barrier_in, int[] barrier_out) -> int");
+      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
+      "int rank, bool full_nvlink) -> int");
   m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
+  m.def(
+      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
+      "int reg_buffer_sz_bytes) -> ()");
   m.impl("all_reduce", torch::kCUDA, &all_reduce);

   /*
    * From csrc/attention
    */
...
@@ -21,6 +21,7 @@ limitations under the License.
 #include <torch/library.h>
 #include <torch/torch.h>

+#include <tuple>
 #include <vector>

 #define _CONCAT(A, B) A##B
@@ -63,18 +64,14 @@ void register_graph_buffers(
 torch::Tensor allocate_meta_buffer(int64_t size);
 torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
 #else
-// TRTLLM custom allreduce
-fptr_t init_custom_ar(
-    int64_t rank_id,
-    int64_t world_size,
-    torch::Tensor& rank_data,
-    const std::vector<fptr_t>& buffers,
-    const std::vector<fptr_t>& tmp_result_buffers,
-    const std::vector<fptr_t>& barrier_in,
-    const std::vector<fptr_t>& barrier_out);
+// custom allreduce
+fptr_t
+init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
 void dispose(fptr_t _fa);
-void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+int64_t meta_size();
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes);
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
 void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
 #endif
...
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
enum class AllReduceStrategyType : int8_t {
RING = 0,
ONESHOT = 1,
TWOSHOT = 2,
AUTO = 3,
};
struct RankData {
void* ptrs[MAX_RANKS_PER_NODE];
};
struct AllReduceParams {
size_t elts_size;
size_t elts_total;
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t ranks_per_node, rank, local_rank;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
RankData* peer_comm_buffer_ptrs;
void* local_input_buffer_ptr;
void* local_output_buffer_ptr;
bool is_capturing;
};
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
if (world_size <= 2) {
return 16 * 1024 * 1024;
}
return 8 * 1024 * 1024;
}
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
if (message_size > maxWorkspaceSize) {
assert(false && "Custom allreduce do not ring currently");
return AllReduceStrategyType::RING;
}
if (world_size <= 2) {
return AllReduceStrategyType::ONESHOT;
}
if (world_size <= 4) {
if (message_size < 1 * 1024 * 1024) {
return AllReduceStrategyType::ONESHOT;
}
return AllReduceStrategyType::TWOSHOT;
}
if (message_size < 512 * 1024) {
return AllReduceStrategyType::ONESHOT;
}
return AllReduceStrategyType::TWOSHOT;
}
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);
} // namespace trt_llm
@@ -50,28 +50,38 @@ if torch.version.hip is not None:
         return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

 else:
-    # TRTLLM custom allreduce
-    def init_custom_reduce(
-        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-    ):
+    def init_custom_ar(
+        ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
+    ) -> int:
         return torch.ops.sgl_kernel.init_custom_ar.default(
-            rank_id,
-            num_devices,
-            rank_data,
-            buffers,
-            tmp_buffers,
-            barrier_in,
-            barrier_out,
+            ipc_tensors, rank_data, rank, full_nvlink
         )

-    def custom_dispose(fa):
+    def dispose(fa: int) -> None:
         torch.ops.sgl_kernel.dispose.default(fa)

-    def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops.sgl_kernel.all_reduce.default(
+            fa, inp, out, reg_buffer, reg_buffer_sz_bytes
+        )

-    def get_graph_buffer_ipc_meta(fa):
+    def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
         return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

-    def register_graph_buffers(fa, handles, offsets):
+    def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
+        return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
         torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
+
+    def meta_size() -> int:
+        return torch.ops.sgl_kernel.meta_size.default()
@@ -16,7 +16,6 @@ from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibra
 def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    ranks = list(range(world_size))
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     dist.init_process_group(
         backend="nccl",
@@ -26,39 +25,18 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
     )
     group = dist.group.WORLD

-    buffer_max_size = 8 * 1024 * 1024
-    barrier_max_size = 8 * (24 + 2) * 8
-
-    buffer_ptrs = None
-    tmp_result_buffer_ptrs = None
-    barrier_in_ptrs = None
-    barrier_out_ptrs = None
-    custom_ptr = None
-
     try:
-        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
-            buffer_max_size, group=group
-        )
-        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
-            barrier_max_size, group=group
-        )
-        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
-            barrier_max_size, group=group
+        device = torch.device(f"cuda:{rank}")
+        max_size = 8192 * 1024
+        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
+            custom_ops.meta_size() + max_size, group=group
         )
         rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
-
-        custom_ptr = custom_ops.init_custom_reduce(
-            rank,
-            world_size,
-            rank_data,
-            buffer_ptrs,
-            tmp_result_buffer_ptrs,
-            barrier_in_ptrs,
-            barrier_out_ptrs,
-        )
+        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
+        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
+        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

         test_loop = 10
         for sz in test_sizes:
@@ -68,7 +46,9 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
                     inp1_ref = inp1.clone()
                     out1 = torch.empty_like(inp1)

-                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
+                    custom_ops.all_reduce(
+                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
+                    )

                     dist.all_reduce(inp1_ref, group=group)
@@ -77,15 +57,11 @@ def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes)
     finally:
         dist.barrier(group=group)
         if custom_ptr is not None:
-            custom_ops.custom_dispose(custom_ptr)
+            custom_ops.dispose(custom_ptr)
         if buffer_ptrs:
             TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
-        if tmp_result_buffer_ptrs:
-            TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
-        if barrier_in_ptrs:
-            TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
-        if barrier_out_ptrs:
-            TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
+        if meta_ptrs:
+            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)

         dist.destroy_process_group(group=group)
@@ -122,7 +98,18 @@ def multi_process_parallel(
 class TestCustomAllReduce(unittest.TestCase):
-    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
+    test_sizes = [
+        512,
+        2560,
+        4096,
+        5120,
+        7680,
+        32768,
+        262144,
+        524288,
+        1048576,
+        2097152,
+    ]
     world_sizes = [2, 4, 8]

     @staticmethod
...