Unverified commit e04d3f28 authored by yizhang2077, committed by GitHub

adapt tensorrt llm custom all reduce to sgl-kernel (#2481)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 5f2595be
cmake_minimum_required(VERSION 3.18)
project(sgl-kernel LANGUAGES CXX CUDA)
# Basic settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
# Set CUDA architectures
set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
# Find PyTorch
execute_process(
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE TORCH_CMAKE_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}")
message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")
find_package(Torch REQUIRED)
# Warp Reduce library
add_library(warp_reduce SHARED
src/sgl-kernel/csrc/warp_reduce.cc
src/sgl-kernel/csrc/warp_reduce_kernel.cu
)
target_include_directories(warp_reduce
    PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
        ${CUDA_INCLUDE_DIRS}
        ${TORCH_INCLUDE_DIRS}
)
target_link_libraries(warp_reduce
    PRIVATE
        ${TORCH_LIBRARIES}
        Python3::Python
)
# TRT Reduce library
add_library(trt_reduce SHARED
src/sgl-kernel/csrc/trt_reduce.cc
src/sgl-kernel/csrc/trt_reduce_internal.cu
src/sgl-kernel/csrc/trt_reduce_kernel.cu
)
target_include_directories(trt_reduce
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)
target_link_libraries(trt_reduce
PRIVATE
${TORCH_LIBRARIES}
Python3::Python
)
# Set common properties for both libraries
foreach(target warp_reduce trt_reduce)
set_target_properties(${target} PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
POSITION_INDEPENDENT_CODE ON
CUDA_RESOLVE_DEVICE_SYMBOLS ON
PREFIX ""
SUFFIX ".so"
)
endforeach()
......@@ -10,7 +10,7 @@ install:
@pip install -e .
build:
@python3 setup.py bdist_wheel
@export MAX_JOBS=$$(nproc) && python3 setup.py bdist_wheel
clean:
@rm -rf build dist *.egg-info
......@@ -19,4 +19,4 @@ test:
@pytest tests/
format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.0.2.post4"
version = "0.0.2.post5"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
......
......@@ -84,7 +84,31 @@ setup(
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
)
),
CUDAExtension(
"sgl_kernel.ops.custom_reduce_cuda",
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce.cc",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
......
from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

__all__ = [
"warp_reduce",
"init_custom_reduce",
"custom_dispose",
"custom_reduce",
]
#include <torch/extension.h>
using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
m.def("dispose", &dispose, "dispose custom allreduce meta");
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>
#include "trt_reduce_internal.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type Converter that packs data format to 128 bits data type
//
using PackedFloat = union {
int4 packed;
float unpacked[4];
};
using PackedHalf = union {
int4 packed;
half2 unpacked[4];
};
template <typename T>
struct PackedOn16Bytes {};
template <>
struct PackedOn16Bytes<float> {
using Type = PackedFloat;
};
template <>
struct PackedOn16Bytes<half> {
using Type = PackedHalf;
};
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
int4 packed;
__nv_bfloat162 unpacked[4];
};
template <>
struct PackedOn16Bytes<__nv_bfloat16> {
using Type = PackedBFloat16;
};
#endif
// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
size_t offset = (flag % 2) ? world_size : 0;
if (bidx == 0) {
st_flag_release(flag, signals[tidx] + offset + local_rank);
}
// All blocks check that corresponding block 0 on other GPUs have set the flag
// No deadlock because block #0 is always the first block started
uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
}
__syncthreads();
}
template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, then writes the result to local_output
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only need to know whether the other GPU has arrived at the AR kernel; that means its data is ready
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. Block sync so the pushed chunk is visible to the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedStruct = typename PackedOn16Bytes<T>::Type;
// The source pointers. Distributed round-robin for the different warps.
T const* buffers[RANKS_PER_NODE];
// Start and end offsets of the thread
size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
}
// Sum the values from the different ranks.
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
sums.packed = add128b(sums, vals[ii]);
}
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int divUp(int a, int b) {
return (a + b - 1) / b;
}
inline int roundUp(int a, int n) {
return divUp(a, n) * n;
}
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
assert(params.elts_total % elts_per_thread == 0);
size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
params.elts_per_rank = params.elts_total;
break;
}
default:
assert(false && "Algorithm not supported here.");
}
return std::make_tuple(blocks_per_grid, threads_per_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
cudaStream_t stream) {
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
void* local_inp_buffer = param.local_input_buffer_ptr;
CHECK_CUDA_SUCCESS(
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only supports ONESHOT");
CHECK_CUDA_SUCCESS(cudaGetLastError());
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node) {
case 2:
dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 4:
dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 6:
dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
case 8:
dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
break;
default:
break;
}
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
switch (data_type) {
case at::ScalarType::Float:
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
break;
case at::ScalarType::Half:
invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16:
invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
break;
#endif
default:
assert(false && "Unsupported data type");
}
}
} // namespace trt_llm
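To make the launch arithmetic in kernelLaunchConfig() above concrete, here is a small Python mirror of the ONESHOT path with one worked example. It is an illustrative sketch only, reusing the constants declared in trt_reduce_internal.cuh; the C++ above remains the authoritative logic.

# Python mirror of the ONESHOT branch of kernelLaunchConfig() (illustration only).
WARP_SIZE = 32
MAX_ALL_REDUCE_BLOCKS = 24
DEFAULT_BLOCK_SIZE = 1024

def div_up(a, b):
    return (a + b - 1) // b

def round_up(a, n):
    return div_up(a, n) * n

def oneshot_launch_config(elts_total, elt_size):
    elts_per_thread = 16 // elt_size          # one 128-bit (int4) packet per thread per step
    assert elts_total % elts_per_thread == 0  # enforced host-side by CanApplyCustomAllReduce
    total_threads = round_up(elts_total // elts_per_thread, WARP_SIZE)
    threads_per_block = min(DEFAULT_BLOCK_SIZE, total_threads)
    blocks_per_grid = min(MAX_ALL_REDUCE_BLOCKS, div_up(total_threads, threads_per_block))
    elts_per_block = round_up(div_up(elts_total, blocks_per_grid), elts_per_thread)
    return blocks_per_grid, threads_per_block, elts_per_block

# 32768 fp16 elements -> (4, 1024, 8192): four blocks, each reducing an 8192-element chunk.
print(oneshot_launch_config(32768, 2))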
// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>
#include "utils.hpp"
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
enum class AllReduceStrategyType : int8_t {
RING = 0,
ONESHOT = 1,
TWOSHOT = 2,
AUTO = 3,
};
struct AllReduceParams {
size_t elts_size;
size_t elts_total;
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t ranks_per_node, rank, local_rank;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_input_buffer_ptr;
void* local_output_buffer_ptr;
};
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
if (world_size <= 2) {
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);
if (message_size > maxWorkspaceSize) {
assert(false && "Custom allreduce do not ring currently");
return AllReduceStrategyType::RING;
}
if (world_size <= 2) {
return AllReduceStrategyType::ONESHOT;
}
if (world_size <= 4) {
if (message_size < 1 * 1000 * 1000) {
return AllReduceStrategyType::ONESHOT;
}
assert(false && "Custom allreduce do not twoshot currently");
return AllReduceStrategyType::TWOSHOT;
}
if (message_size < 500 * 1000) {
return AllReduceStrategyType::ONESHOT;
}
assert(false && "Custom allreduce do not twoshot currently");
return AllReduceStrategyType::TWOSHOT;
}
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream);
} // namespace trt_llm
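Since only the ONESHOT branch is implemented in this commit, the selection thresholds above are worth restating compactly. A hedged Python mirror of GetMaxRequiredWorkspaceSize()/SelectImplementation(), for illustration only (the C++ asserts whenever the RING or TWOSHOT branch would be chosen):

# Python mirror of the strategy-selection thresholds (illustration only).
def max_required_workspace_size(world_size: int) -> int:
    return 16 * 1000 * 1000 if world_size <= 2 else 8 * 1000 * 1000

def select_implementation(message_size: int, world_size: int) -> str:
    if message_size > max_required_workspace_size(world_size):
        return "RING"  # not supported by this custom kernel
    if world_size <= 2:
        return "ONESHOT"
    if world_size <= 4:
        return "ONESHOT" if message_size < 1 * 1000 * 1000 else "TWOSHOT"
    return "ONESHOT" if message_size < 500 * 1000 else "TWOSHOT"

# 512 KB on 4 ranks stays ONESHOT; 2 MB on 8 ranks would need TWOSHOT (unsupported here).
print(select_implementation(512 * 1000, 4), select_implementation(2 * 1000 * 1000, 8))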
// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include "trt_reduce_internal.cuh"
using namespace trt_llm;
using fptr_t = int64_t;
class AllReduceMeta {
public:
AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->buffers = buffers;
this->barrier_in = barrier_in;
this->barrier_out = barrier_out;
}
public:
int world_size;
int rank_id;
std::vector<fptr_t> buffers;
std::vector<fptr_t> barrier_in;
std::vector<fptr_t> barrier_out;
int barrier_flag = 1;
};
// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
switch (dtype) {
case at::ScalarType::Float:
return 32;
case at::ScalarType::Half:
case at::ScalarType::BFloat16:
return 16;
default:
assert(false && "Unsupported data type");
}
}
// Check if customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
// The customized all-reduce kernel has the following requirement(s).
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
delete fa;
}
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
auto stream = c10::cuda::getCurrentCUDAStream().stream();
auto num_elements = inp.numel();
auto dtype = inp.scalar_type();
AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);
// should be guaranteed in the Python code
assert(strategy == AllReduceStrategyType::ONESHOT);
assert(CanApplyCustomAllReduce(num_elements, dtype));
// Initialize the all-reduce kernel arguments.
int world_size = m->world_size;
AllReduceParams params;
params.ranks_per_node = world_size;
params.rank = m->rank_id;
params.local_rank = m->rank_id;
params.local_input_buffer_ptr = inp.data_ptr();
params.local_output_buffer_ptr = out.data_ptr();
params.elts_total = inp.numel();
params.elts_size = inp.element_size();
params.barrier_flag = ++(m->barrier_flag);
for (int i = 0; i < world_size; ++i) {
params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
}
for (int i = 0; i < world_size; ++i) {
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
}
auto data_type = out.scalar_type();
trtCustomAllReduce(params, data_type, strategy, stream);
}
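The CanApplyCustomAllReduce() requirement above reduces to a simple divisibility rule: the kernel moves data in 16-byte (int4) packets, so the element count must split evenly into 128-bit chunks. A small hedged Python restatement, for illustration only:

# Python restatement of the eligibility check (illustration only).
BITS = {"float32": 32, "float16": 16, "bfloat16": 16}

def can_apply_custom_all_reduce(num_elements: int, dtype: str) -> bool:
    bytes_per_elt = (BITS[dtype] + 7) // 8
    return num_elements % (16 // bytes_per_elt) == 0

# fp16/bf16 need a multiple of 8 elements, fp32 a multiple of 4.
assert can_apply_custom_all_reduce(4096, "float16")
assert not can_apply_custom_all_reduce(4097, "float16")
assert can_apply_custom_all_reduce(4100, "float32")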
#pragma once
#include <torch/extension.h>
#include <sstream>
struct cuda_error : public std::runtime_error {
/**
* @brief Constructs a `cuda_error` object with the given `message`.
*
* @param message The error char array used to construct `cuda_error`
*/
cuda_error(const char* message) : std::runtime_error(message) {}
/**
* @brief Constructs a `cuda_error` object with the given `message` string.
*
* @param message The `std::string` used to construct `cuda_error`
*/
cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
std::stringstream _message; \
auto s = cudaGetErrorString(e); \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
throw cuda_error(_message.str()); \
} \
} while (0)
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_IS_CUDA(x); \
CHECK_IS_CONTIGUOUS(x)
#include <torch/extension.h>
#include "utils.hpp"

torch::Tensor warp_reduce_cuda(torch::Tensor input);

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_CUDA_INPUT(input);
return warp_reduce_cuda(input);
}
......
from .custom_reduce_cuda import all_reduce as _all_reduce
from .custom_reduce_cuda import dispose as _dispose
from .custom_reduce_cuda import init_custom_ar as _init_custom_ar
from .warp_reduce_cuda import reduce as _reduce
def warp_reduce(input_tensor):
return _reduce(input_tensor)
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
def custom_dispose(fa):
_dispose(fa)
def custom_reduce(fa, inp, out):
_all_reduce(fa, inp, out)
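A hedged end-to-end sketch of how these wrappers are intended to be used, modeled on the unit test below. Here shared_ptrs is a hypothetical stand-in for the test's create_shared_buffer() helper, which cudaMalloc's a buffer on every rank and exchanges the device pointers via CUDA IPC handles; torch.distributed is assumed to be initialized already.

import torch
import sgl_kernel

def run_custom_all_reduce(rank, world_size, shared_ptrs, inp, out):
    # Allocation sizes mirror the unit test below.
    buffer_ptrs = shared_ptrs(8 * 1024 * 1024)       # shared communication buffer per rank
    barrier_in_ptrs = shared_ptrs(8 * (24 + 2) * 8)  # barrier flag buffers
    barrier_out_ptrs = shared_ptrs(8 * (24 + 2) * 8)

    fa = sgl_kernel.ops.init_custom_reduce(
        rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs
    )
    # inp/out must be contiguous CUDA tensors of the same shape and dtype.
    sgl_kernel.ops.custom_reduce(fa, inp, out)
    torch.cuda.synchronize()
    sgl_kernel.ops.custom_dispose(fa)  # frees the metadata; the IPC buffers are freed separately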
import ctypes
import logging
import os
import random
import socket
import time
import unittest
from typing import Any, List, Optional, Union
import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
logger = logging.getLogger(__name__)
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int,
cls: Any,
test_target: Any,
) -> None:
# Using ray instead of multiprocessing makes failures easier to debug.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(log_to_driver=True)
distributed_init_port = get_open_port()
refs = []
for rank in range(world_size):
refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
ray.get(refs)
ray.shutdown()
class TestCustomAllReduce(unittest.TestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
cls.test_sizes = {
2: [512, 4096, 32768, 262144, 2097152],
4: [512, 4096, 32768, 131072],
6: [512, 4096, 32768, 65536],
8: [512, 4096, 32768, 65536],
}
cls.world_sizes = [2, 4, 6, 8]
@staticmethod
def create_shared_buffer(
size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
else:
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return pointers
@staticmethod
def free_shared_buffer(
pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
def test_correctness(self):
for world_size in self.world_sizes:
if world_size > torch.cuda.device_count():
continue
multi_process_parallel(world_size, self, self.correctness)
def test_performance(self):
for world_size in self.world_sizes:
if world_size > torch.cuda.device_count():
continue
multi_process_parallel(world_size, self, self.performance)
def init_custom_allreduce(self, rank, world_size, group):
import sgl_kernel
buffer_max_size = 8 * 1024 * 1024
barrier_max_size = 8 * (24 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
rank,
world_size,
self.buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
def custom_allreduce(self, inp, out):
import sgl_kernel
sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)
def free_custom_allreduce(self, group):
import sgl_kernel
self.free_shared_buffer(self.buffer_ptrs, group)
self.free_shared_buffer(self.barrier_in_ptrs, group)
self.free_shared_buffer(self.barrier_out_ptrs, group)
sgl_kernel.ops.custom_dispose(self.custom_ptr)
def init_vllm_allreduce(self, rank, group):
self.vllm_rank = rank
self.vllm_max_size = 8 * 1024 * 1024
self.vllm_meta_ptrs = self.create_shared_buffer(
vllm_ops.meta_size() + self.vllm_max_size, group=group
)
self.vllm_buffer_ptrs = self.create_shared_buffer(
self.vllm_max_size, group=group
)
self.vllm_rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
)
self.vllm_ptr = vllm_ops.init_custom_ar(
self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
)
vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
def vllm_allreduce(self, inp, out):
vllm_ops.all_reduce(
self.vllm_ptr,
inp,
out,
self.vllm_buffer_ptrs[self.vllm_rank],
self.vllm_max_size,
)
def free_vllm_allreduce(self, group):
vllm_ops.dispose(self.vllm_ptr)
self.free_shared_buffer(self.vllm_meta_ptrs, group)
self.free_shared_buffer(self.vllm_buffer_ptrs, group)
@staticmethod
def init_distributed_env(world_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
ranks = [i for i in range(world_size)]
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = torch.distributed.new_group(ranks, backend="gloo")
return group
# compare result with torch.distributed
@ray.remote(num_gpus=1, max_calls=1)
def correctness(self, world_size, rank, distributed_init_port):
group = self.init_distributed_env(world_size, rank, distributed_init_port)
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
test_loop = 10
for sz in self.test_sizes[world_size]:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
out1 = torch.empty_like(inp1)
self.custom_allreduce(inp1, out1)
dist.all_reduce(inp1, group=group)
torch.testing.assert_close(out1, inp1)
self.free_custom_allreduce(group)
# compare performance with vllm
@ray.remote(num_gpus=1, max_calls=1)
def performance(self, world_size, rank, distributed_init_port):
group = self.init_distributed_env(world_size, rank, distributed_init_port)
self.init_vllm_allreduce(rank, group)
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
for sz in self.test_sizes[world_size]:
inp1 = torch.randint(
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
)
out1 = torch.empty_like(inp1)
test_loop = 5000
start = time.time()
for _ in range(test_loop):
self.custom_allreduce(inp1, out1)
elapse_custom = time.time() - start
start = time.time()
for _ in range(test_loop):
self.vllm_allreduce(inp1, out1)
elapse_vllm = time.time() - start
if rank == 0:
logger.warning(
f"test_size = {sz}, world_size = {world_size}, "
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}us,"
f"custom time = {elapse_custom * 1000 / test_loop:.4f}us"
)
self.free_custom_allreduce(group)
self.free_vllm_allreduce(group)
if __name__ == "__main__":
unittest.main()