Unverified Commit 6cb3974e authored by yizhang2077, committed by GitHub

optimize custom allreduce kernel (#2904)

parent f65c13b5
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.0.2.post12"
version = "0.0.2.post13"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
......
......@@ -40,7 +40,7 @@ nvcc_flags = [
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
......
from sgl_kernel.ops import (
custom_dispose,
custom_reduce,
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
moe_align_block_size,
register_graph_buffers,
sampling_scaling_penalties,
)
......@@ -14,4 +16,6 @@ __all__ = [
"custom_reduce",
"int8_scaled_mm",
"sampling_scaling_penalties",
"get_graph_buffer_ipc_meta",
"register_graph_buffers",
]
......@@ -2,10 +2,14 @@
// trt_reduce
using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
// moe_align_block_size
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
......@@ -25,6 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
m.def("dispose", &dispose, "dispose custom allreduce meta");
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
// moe_align_block_size
m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
// sampling_scaling_penalties
......
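The binding changes above expose two extra entry points (`get_graph_buffer_ipc_meta`, `register_graph_buffers`) and extend `init_custom_ar` with a `rank_data` tensor and per-rank temporary result buffers. A hedged sketch of the resulting API surface and its intended call order; the buffer lists are assumed to be IPC-shared device pointers prepared elsewhere (see the test changes at the bottom of this diff for a concrete setup), and all other names are placeholders:

```python
# Sketch only: exercises the extended init_custom_ar signature plus the
# eager-path all_reduce; graph-related bindings are needed only under capture.
import torch
from sgl_kernel.ops._kernels import all_reduce, dispose, init_custom_ar


def allreduce_lifecycle(rank, world_size, buffers, tmp_result_buffers,
                        barrier_in, barrier_out, inp, out):
    # Device workspace that backs the RankData pointer tables.
    rank_data = torch.empty(
        8 * 1024 * 1024, dtype=torch.uint8, device=f"cuda:{rank}"
    )
    fa = init_custom_ar(
        rank, world_size, rank_data, buffers, tmp_result_buffers,
        barrier_in, barrier_out,
    )
    all_reduce(fa, inp, out)  # eager path: the kernel copies inp into the shared buffer
    # get_graph_buffer_ipc_meta / register_graph_buffers come into play only
    # when all_reduce runs under CUDA graph capture (see the kernel changes below).
    dispose(fa)
```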
......@@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
__syncthreads();
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx, int const grid_size,
bool start = true, bool need_fence = false) {
if (!start) {
size_t const world_size, int const tidx, int const bidx, int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
// After this function, the block of id == bidx of each GPU has reached the barrier
......@@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
// Each block broadcasts its flag (local_rank on the emitting dimension) to all receivers
uint32_t flag_block_offset = world_size + bidx * world_size;
if (flag % 2 == 1) {
flag_block_offset += (grid_size + 1) * world_size;
}
flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
if (need_fence) {
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
} else {
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
}
// Blocks check that corresponding blocks on other GPUs have also set the flag
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
if (need_fence) {
// Blocks check that corresponding blocks on other GPUs have also set the flag
if constexpr (need_fence) {
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
} else {
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
while (ld_flag_volatile(peer_barrier_d) != flag) {
}
}
......@@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
__syncthreads();
}
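Two things change in `block_barrier`: the `start`/`need_fence` booleans move from runtime arguments to template parameters, so `if constexpr` drops the dead branch at compile time, and the odd-flag offset bump becomes a branchless multiply by `flag % 2`. A small standalone check (pure Python, no CUDA needed) that the arithmetic rewrite is equivalent to the old branch:

```python
# Verifies: off += (grid_size + 1) * world_size * (flag % 2) matches the old
# "if (flag % 2 == 1)" bump for representative world sizes and block indices.
def offset_branch(flag, world_size, bidx, grid_size):
    off = world_size + bidx * world_size
    if flag % 2 == 1:
        off += (grid_size + 1) * world_size
    return off


def offset_branchless(flag, world_size, bidx, grid_size):
    return world_size + bidx * world_size + (grid_size + 1) * world_size * (flag % 2)


for flag in range(4):
    for world_size in (2, 4, 6, 8):
        for bidx in range(16):
            assert offset_branch(flag, world_size, bidx, 16) == \
                   offset_branchless(flag, world_size, bidx, 16)
```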
template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
......@@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;
// The number of elements packed into one 16-byte access for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
......@@ -201,18 +196,23 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
using PackedStruct = typename PackedOn16Bytes<T>::Type;
// The source pointers. Distributed round-robin for the different warps.
T const* buffers[RANKS_PER_NODE];
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
if constexpr (COPY_INPUT) {
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
// Copy from local buffer to shareable buffer
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
}
}
// Wait for the corresponding blocks on other GPUs to have copied data to their shareable buffer
block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
......@@ -220,7 +220,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
}
// Sum the values from the different ranks.
......@@ -229,8 +229,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
sums.packed = add128b(sums, vals[ii]);
sums.packed = add128b(sums, vals[rank]);
}
// Store to the destination buffer.
......@@ -238,7 +237,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
}
}
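The rotated `buffers[]` indirection in the reduction loop is gone: peer pointers are now read in natural rank order from the device-side pointer table, and `vals[rank]` is accumulated directly. The accumulation order is unchanged, as this small standalone check (pure Python) illustrates:

```python
# Old scheme: buffers[ii] held data from rank (local_rank + ii) % R, and the
# loop picked ii = (rank + R - local_rank) % R. New scheme: read in natural
# order and accumulate vals[rank]. Both visit rank 0, 1, ..., R-1 in order.
def old_order(local_rank, ranks_per_node):
    order = []
    for rank in range(ranks_per_node):
        ii = (rank + ranks_per_node - local_rank) % ranks_per_node
        order.append((local_rank + ii) % ranks_per_node)
    return order


def new_order(local_rank, ranks_per_node):
    return list(range(ranks_per_node))


for world in (2, 4, 6, 8):
    for local_rank in range(world):
        assert old_order(local_rank, world) == new_order(local_rank, world)
print("reduction order unchanged: always rank 0 .. N-1")
```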
template <typename T, int RANKS_PER_NODE>
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
......@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedType = typename PackedOn16Bytes<T>::Type;
T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
T* buffers[RANKS_PER_NODE];
T* buffers_unorder[RANKS_PER_NODE];
int ranks[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// A rank mapping that scatters reads across peers as much as possible
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
ranks[ii] = rank;
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
......@@ -308,8 +311,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
#endif
#endif
block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);
if constexpr (COPY_INPUT) {
// Copy all blocks from local buffer to shareable buffer
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
*reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
}
}
}
block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
......@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
}
// Sum the values from the different ranks.
......@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
sums.packed = add128b(sums, vals[ii]);
sums.packed = add128b(sums, vals[rank]);
}
// Store to the local buffer.
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
// Store to the local buffer or tmp buffer
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
} else {
*reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
}
}
block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size, false, true);
block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
......@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
if (offset_rank >= params.elts_total) {
continue;
}
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
if constexpr (COPY_INPUT) {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
} else {
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
*reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
}
}
}
#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
......@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE>
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
case AllReduceStrategyType::TWOSHOT: {
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
void* local_inp_buffer = param.local_input_buffer_ptr;
CHECK_CUDA_SUCCESS(
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
CHECK_CUDA_SUCCESS(cudaGetLastError());
template <typename T, bool COPY_INPUT>
void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
if (param.is_capturing) {
dispatchARKernelsCopyInput<T, false>(strat, param, stream);
} else {
dispatchARKernelsCopyInput<T, true>(strat, param, stream);
}
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
......
......@@ -36,6 +36,10 @@ enum class AllReduceStrategyType : int8_t {
AUTO = 3,
};
struct RankData {
void* ptrs[MAX_RANKS_PER_NODE];
};
struct AllReduceParams {
size_t elts_size;
size_t elts_total;
......@@ -46,9 +50,11 @@ struct AllReduceParams {
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
RankData* peer_comm_buffer_ptrs;
void* local_input_buffer_ptr;
void* local_output_buffer_ptr;
bool is_capturing;
};
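`peer_comm_buffer_ptrs` is now a single device pointer to a `RankData` pointer table instead of an inline host-side array; because the kernel dereferences the table at run time, a captured graph can pick up buffer addresses that are registered only after capture. The table slots live in the caller-provided `rank_data` tensor. A rough capacity sketch, assuming `MAX_RANKS_PER_NODE == 8` and 64-bit pointers (both assumptions, not confirmed by this diff):

```python
# One RankData slot is consumed at init time for the eager-mode buffers, and
# one per recorded input buffer at each register_graph_buffers call.
MAX_RANKS_PER_NODE = 8            # assumption; defined in the C++ header
POINTER_SIZE = 8                  # bytes per void* on a 64-bit build
rank_data_bytes = 8 * 1024 * 1024  # size used by the test below
slot_bytes = MAX_RANKS_PER_NODE * POINTER_SIZE
print(rank_data_bytes // slot_bytes)  # -> 131072 available RankData slots
```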
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
......
......@@ -12,25 +12,46 @@
using namespace trt_llm;
using fptr_t = int64_t;
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
public:
AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->buffers = buffers;
this->barrier_in = barrier_in;
this->barrier_out = barrier_out;
this->tmp_result_buffers = tmp_result_buffers;
this->rank_data_base = reinterpret_cast<RankData*>(rank_data.data_ptr());
RankData data;
for (int i = 0; i < world_size; i++) {
data.ptrs[i] = (void*)buffers[i];
}
auto d_data = this->rank_data_base++;
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
this->buffers = d_data;
}
~AllReduceMeta() {
for (auto [_, ptr] : ipc_handles_) {
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
}
}
public:
int world_size;
int rank_id;
std::vector<fptr_t> buffers;
std::vector<fptr_t> barrier_in;
std::vector<fptr_t> barrier_out;
std::vector<fptr_t> tmp_result_buffers;
int barrier_flag = 1;
RankData* buffers;
RankData* rank_data_base;
std::vector<void*> graph_unreg_buffers;
std::map<IPC_KEY, char*> ipc_handles_;
};
// Get the number of bits for a given data type.
......@@ -52,9 +73,10 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
......@@ -63,6 +85,75 @@ void dispose(fptr_t _fa) {
delete fa;
}
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto num_buffers = m->graph_unreg_buffers.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = m->graph_unreg_buffers[i];
void* base_ptr;
// Note: we must share the base address of each allocation, or we get the wrong address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) {
assert(false && "failed to get pointer attr");
}
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
std::vector<int64_t> bytes(handles.begin(), handles.end());
return std::make_pair(bytes, offsets);
}
char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
std::vector<std::string> handle_bytes;
handle_bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
handle_bytes.emplace_back(handles[i].begin(), handles[i].end());
}
auto num_buffers = m->graph_unreg_buffers.size();
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = m->graph_unreg_buffers[i];
auto& rd = rank_data[i];
for (int j = 0; j < m->world_size; j++) {
if (j != m->rank_id) {
char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CHECK_CUDA_SUCCESS(
cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
m->rank_data_base += num_buffers;
m->graph_unreg_buffers.clear();
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto stream = c10::cuda::getCurrentCUDAStream().stream();
......@@ -87,8 +178,18 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
params.elts_size = inp.element_size();
params.barrier_flag = ++(m->barrier_flag);
cudaStreamCaptureStatus status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
params.is_capturing = (status == cudaStreamCaptureStatusActive);
if (params.is_capturing) {
params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size();
m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr);
} else {
params.peer_comm_buffer_ptrs = m->buffers;
}
for (int i = 0; i < world_size; ++i) {
params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
params.tmp_result_buffers[i] = reinterpret_cast<uint32_t*>(m->tmp_result_buffers[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
......
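During stream capture (`cudaStreamIsCapturing` reports active), `all_reduce` no longer copies or dereferences live buffer addresses: it records the input pointer in `graph_unreg_buffers` and points `peer_comm_buffer_ptrs` at a not-yet-filled `RankData` slot that `register_graph_buffers` populates later, while the kernel stages partial sums in `tmp_result_buffers`. A hedged sketch of the capture-side usage; the handle `fa` and the tensors are assumed to come from an existing `init_custom_reduce` setup:

```python
# Sketch only: registration of the recorded buffers (via
# get_graph_buffer_ipc_meta / register_graph_buffers) must happen on every
# rank before the graph is replayed.
import torch
from sgl_kernel import ops as custom_ops


def capture_allreduce(fa, inp, out):
    # Warm-up iterations on a side stream are typically needed before capture.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Under capture, the kernel reads inp in place (COPY_INPUT == false)
        # and writes its partial results into tmp_result_buffers.
        custom_ops.custom_reduce(fa, inp, out)
    return graph
```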
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import (
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
)
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
)
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
return _init_custom_ar(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
)
def custom_dispose(fa):
......@@ -20,6 +28,14 @@ def custom_reduce(fa, inp, out):
_all_reduce(fa, inp, out)
def get_graph_buffer_ipc_meta(fa):
return _get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa, handles, offsets):
_register_graph_buffers(fa, handles, offsets)
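A hedged sketch of how these two wrappers could be combined after graph capture: each rank exports IPC handles for the buffers it recorded, the handles are exchanged across ranks (here with `torch.distributed.all_gather_object`, purely for illustration), and every rank registers the full set. The helper name and the assumption that `torch.distributed` is initialized are mine, not part of this change:

```python
# Sketch only; uses the two wrappers defined above and a handle "fa" returned
# by init_custom_reduce.
import torch.distributed as dist


def exchange_and_register(fa, group=None):
    handle_bytes, offsets = get_graph_buffer_ipc_meta(fa)
    world_size = dist.get_world_size(group=group)
    gathered = [None] * world_size
    dist.all_gather_object(gathered, (handle_bytes, offsets), group=group)
    all_handles = [g[0] for g in gathered]
    all_offsets = [g[1] for g in gathered]
    register_graph_buffers(fa, all_handles, all_offsets)
```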
def moe_align_block_size(
topk_ids,
num_experts,
......
......@@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
import ray
import torch
import torch.distributed as dist
from sgl_kernel import ops as custom_ops
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops
......@@ -104,35 +105,38 @@ class TestCustomAllReduce(unittest.TestCase):
multi_process_parallel(world_size, self, self.performance)
def init_custom_allreduce(self, rank, world_size, group):
import sgl_kernel
buffer_max_size = 8 * 1024 * 1024
barrier_max_size = 8 * (24 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
buffer_max_size, group=group
)
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
)
self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
self.custom_ptr = custom_ops.init_custom_reduce(
rank,
world_size,
self.rank_data,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
def custom_allreduce(self, inp, out):
import sgl_kernel
sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
custom_ops.custom_reduce(self.custom_ptr, inp, out)
def free_custom_allreduce(self, group):
import sgl_kernel
self.free_shared_buffer(self.buffer_ptrs, group)
self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
self.free_shared_buffer(self.barrier_in_ptrs, group)
self.free_shared_buffer(self.barrier_out_ptrs, group)
sgl_kernel.ops.custom_dispose(self.custom_ptr)
custom_ops.custom_dispose(self.custom_ptr)
def init_vllm_allreduce(self, rank, group):
self.vllm_rank = rank
......