Unverified Commit b02da24a authored by Ke Bao, committed by GitHub

Refactor sgl-kernel build (#2642)

parent bdd2827a
CMakeLists.txt
@@ -25,46 +25,29 @@ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")
find_package(Torch REQUIRED)
# Warp Reduce library
add_library(warp_reduce SHARED
src/sgl-kernel/csrc/warp_reduce.cc
add_library(_kernels SHARED
src/sgl-kernel/csrc/warp_reduce_kernel.cu
)
target_include_directories(warp_reduce
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)
target_link_libraries(warp_reduce
PRIVATE
${TORCH_LIBRARIES}
Python3::Python
)
# TRT Reduce library
add_library(trt_reduce SHARED
src/sgl-kernel/csrc/trt_reduce.cc
src/sgl-kernel/csrc/trt_reduce_internal.cu
src/sgl-kernel/csrc/trt_reduce_kernel.cu
src/sgl-kernel/csrc/moe_align_kernel.cu
src/sgl-kernel/csrc/sgl_kernel_ops.cu
)
target_include_directories(trt_reduce
target_include_directories(_kernels
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
${CUDA_INCLUDE_DIRS}
${TORCH_INCLUDE_DIRS}
)
target_link_libraries(trt_reduce
target_link_libraries(_kernels
PRIVATE
${TORCH_LIBRARIES}
Python3::Python
)
# Set common properties for both libraries
foreach(target warp_reduce trt_reduce)
foreach(target _kernels)
set_target_properties(${target} PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
POSITION_INDEPENDENT_CODE ON
......
setup.py
@@ -58,78 +58,45 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)
nvcc_flags = [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=[
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
extra_compile_args={
"nvcc": nvcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
),
]
setup(
name="sgl-kernel",
version=get_version(),
packages=["sgl_kernel"],
package_dir={"": "src"},
ext_modules=[
CUDAExtension(
"sgl_kernel.ops.warp_reduce_cuda",
[
"src/sgl-kernel/csrc/warp_reduce.cc",
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.custom_reduce_cuda",
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce.cc",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.moe_align_block_size",
[
"src/sgl-kernel/csrc/moe_align_kernel.cu",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
......
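With this change, setup.py builds a single CUDAExtension, sgl_kernel.ops._kernels, from all five .cu sources instead of three separate per-op modules. A minimal post-build sanity check might look like the sketch below, assuming the package has already been built and installed (e.g. via pip install .) and that the binding names match the PYBIND11_MODULE in sgl_kernel_ops.cu shown further down:

    import sgl_kernel.ops._kernels as _kernels

    # Binding names taken from the consolidated PYBIND11_MODULE in sgl_kernel_ops.cu.
    for name in ("reduce", "init_custom_ar", "dispose", "all_reduce", "moe_align_block_size"):
        assert hasattr(_kernels, name), f"missing binding: {name}"
    print("all _kernels bindings present")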
sgl_kernel/__init__.py
from .ops import moe_align_block_size
from sgl_kernel.ops import (
custom_dispose,
custom_reduce,
init_custom_reduce,
moe_align_block_size,
warp_reduce,
)
__all__ = [
"moe_align_block_size",
"warp_reduce",
"init_custom_reduce",
"custom_dispose",
"custom_reduce",
]
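After the refactor, the public wrappers are re-exported from the package root, so callers import them from sgl_kernel directly. A short usage sketch, assuming a CUDA device and a hypothetical 1-D float32 input (the reduction semantics are defined by the kernel in warp_reduce_kernel.cu, which is not shown in full here):

    import torch
    from sgl_kernel import warp_reduce

    x = torch.randn(1024, dtype=torch.float32, device="cuda")  # hypothetical input
    out = warp_reduce(x)  # thin wrapper over the compiled `reduce` binding in _kernels
    print(out.shape, out.dtype)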
src/sgl-kernel/csrc/moe_align_kernel.cu
@@ -3,11 +3,11 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include "utils.hpp"
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
@@ -133,7 +133,3 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
}
src/sgl-kernel/csrc/sgl_kernel_ops.cu
#include <torch/extension.h>
#include "utils.hpp"
// warp_reduce
torch::Tensor warp_reduce_cuda(torch::Tensor input);
torch::Tensor warp_reduce(torch::Tensor input) {
CHECK_CUDA_INPUT(input);
return warp_reduce_cuda(input);
}
// trt_reduce
using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// moe_align_block_size
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// warp_reduce
m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
m.def("dispose", &dispose, "dispose custom allreduce meta");
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
// moe_align_block_size
m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
}
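sgl_kernel_ops.cu is now the single pybind11 entry point: the warp reduce, TRT-style custom all-reduce, and MoE align-block-size bindings that previously lived in separate modules are all registered here. The argument order for moe_align_block_size matches the declaration above; the buffer shapes in the sketch below are assumptions modeled on common MoE block-alignment usage, not taken from this diff:

    import torch
    from sgl_kernel.ops._kernels import moe_align_block_size

    topk_ids = torch.randint(0, 8, (16, 2), dtype=torch.int32, device="cuda")  # hypothetical routing
    num_experts, block_size = 8, 16

    # Assumed sizing (not from this diff): worst case pads block_size - 1 slots per expert.
    max_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_token_ids = torch.empty(max_tokens_padded, dtype=torch.int32, device="cuda")
    experts_ids = torch.empty((max_tokens_padded + block_size - 1) // block_size,
                              dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
    token_cnts_buffer = torch.empty((num_experts + 1) * num_experts, dtype=torch.int32, device="cuda")
    cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")

    moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
                         experts_ids, num_tokens_post_pad, token_cnts_buffer, cumsum_buffer)
    print(num_tokens_post_pad.item())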
src/sgl-kernel/csrc/warp_reduce.cc
#include <torch/extension.h>
#include "utils.hpp"
torch::Tensor warp_reduce_cuda(torch::Tensor input);
torch::Tensor warp_reduce(torch::Tensor input) {
CHECK_CUDA_INPUT(input);
return warp_reduce_cuda(input);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
}
src/sgl-kernel/csrc/warp_reduce_kernel.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "utils.hpp"
#define FINAL_MASK 0xffffffff
#define BLOCK_SIZE 256
......
sgl_kernel/ops/__init__.py
from .moe_align_block_size import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import reduce as _reduce
def warp_reduce(input_tensor):
return _reduce(input_tensor)
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
def custom_dispose(fa):
_dispose(fa)
def custom_reduce(fa, inp, out):
_all_reduce(fa, inp, out)
def moe_align_block_size(
......