sglang · Commit b02da24a (unverified)
Refactor sgl-kernel build (#2642)
Authored Dec 30, 2024 by Ke Bao; committed via GitHub on Dec 30, 2024
Parent commit: bdd2827a
Showing 8 changed files with 108 additions and 113 deletions.
sgl-kernel/CMakeLists.txt                               +6   −23
sgl-kernel/setup.py                                     +34  −67
sgl-kernel/src/sgl-kernel/__init__.py                   +11  −1
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu      +2   −6
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu        +32  −0
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc           +0   −14
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu    +2   −1
sgl-kernel/src/sgl-kernel/ops/__init__.py               +21  −1
sgl-kernel/CMakeLists.txt
...
@@ -25,46 +25,29 @@ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

find_package(Torch REQUIRED)

-# Warp Reduce library
-add_library(warp_reduce SHARED
-    src/sgl-kernel/csrc/warp_reduce.cc
+add_library(_kernels SHARED
     src/sgl-kernel/csrc/warp_reduce_kernel.cu
-)
-
-target_include_directories(warp_reduce PRIVATE
-    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
-    ${CUDA_INCLUDE_DIRS}
-    ${TORCH_INCLUDE_DIRS}
-)
-
-target_link_libraries(warp_reduce PRIVATE ${TORCH_LIBRARIES} Python3::Python)
-
-# TRT Reduce library
-add_library(trt_reduce SHARED
-    src/sgl-kernel/csrc/trt_reduce.cc
     src/sgl-kernel/csrc/trt_reduce_internal.cu
     src/sgl-kernel/csrc/trt_reduce_kernel.cu
+    src/sgl-kernel/csrc/moe_align_kernel.cu
+    src/sgl-kernel/csrc/sgl_kernel_ops.cu
)

-target_include_directories(trt_reduce PRIVATE
+target_include_directories(_kernels PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
    ${CUDA_INCLUDE_DIRS}
    ${TORCH_INCLUDE_DIRS}
)

-target_link_libraries(trt_reduce PRIVATE ${TORCH_LIBRARIES} Python3::Python)
+target_link_libraries(_kernels PRIVATE ${TORCH_LIBRARIES} Python3::Python)

-# Set common properties for both libraries
-foreach(target warp_reduce trt_reduce)
+foreach(target _kernels)
  set_target_properties(${target} PROPERTIES
    CUDA_SEPARABLE_COMPILATION ON
    POSITION_INDEPENDENT_CODE ON
...
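The hunk above relies on CMAKE_PREFIX_PATH being extended with ${TORCH_CMAKE_PATH} so that find_package(Torch REQUIRED) resolves against the installed PyTorch. How that variable is populated is not shown in this diff; a minimal sketch, assuming the value is supplied from the Python side, could look like:

# Sketch only: print the CMake prefix path of the installed PyTorch so it can
# be handed to CMake (e.g. as TORCH_CMAKE_PATH, the variable used in the hunk
# above). The wiring into the real build script is an assumption.
import torch.utils

print(torch.utils.cmake_prefix_path)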
sgl-kernel/setup.py
...
@@ -58,78 +58,45 @@ def update_wheel_platform_tag():
    old_wheel.rename(new_wheel)

+nvcc_flags = [
+    "-O3",
+    "-Xcompiler",
+    "-fPIC",
+    "-gencode=arch=compute_75,code=sm_75",
+    "-gencode=arch=compute_80,code=sm_80",
+    "-gencode=arch=compute_89,code=sm_89",
+    "-gencode=arch=compute_90,code=sm_90",
+    "-U__CUDA_NO_HALF_OPERATORS__",
+    "-U__CUDA_NO_HALF2_OPERATORS__",
+]
+cxx_flags = ["-O3"]
+libraries = ["c10", "torch", "torch_python"]
+extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
+ext_modules = [
+    CUDAExtension(
+        name="sgl_kernel.ops._kernels",
+        sources=[
+            "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
+            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+            "src/sgl-kernel/csrc/moe_align_kernel.cu",
+            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+        ],
+        extra_compile_args={
+            "nvcc": nvcc_flags,
+            "cxx": cxx_flags,
+        },
+        libraries=libraries,
+        extra_link_args=extra_link_args,
+    ),
+]
+
setup(
    name="sgl-kernel",
    version=get_version(),
    packages=["sgl_kernel"],
    package_dir={"": "src"},
-    ext_modules=[
-        CUDAExtension(
-            "sgl_kernel.ops.warp_reduce_cuda",
-            [
-                "src/sgl-kernel/csrc/warp_reduce.cc",
-                "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-        CUDAExtension(
-            "sgl_kernel.ops.custom_reduce_cuda",
-            [
-                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
-                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
-                "src/sgl-kernel/csrc/trt_reduce.cc",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF2_OPERATORS__",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-        CUDAExtension(
-            "sgl_kernel.ops.moe_align_block_size",
-            [
-                "src/sgl-kernel/csrc/moe_align_kernel.cu",
-            ],
-            extra_compile_args={
-                "nvcc": [
-                    "-O3",
-                    "-Xcompiler",
-                    "-fPIC",
-                    "-gencode=arch=compute_75,code=sm_75",
-                    "-gencode=arch=compute_80,code=sm_80",
-                    "-gencode=arch=compute_89,code=sm_89",
-                    "-gencode=arch=compute_90,code=sm_90",
-                ],
-                "cxx": ["-O3"],
-            },
-            libraries=["c10", "torch", "torch_python"],
-            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-        ),
-    ],
+    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
...
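The consolidated extension is compiled only for the architectures in nvcc_flags (sm_75, sm_80, sm_89, sm_90). A small sketch, not part of setup.py, for checking whether the local GPU is covered before building:

# Illustrative pre-build check against the -gencode list above; setup.py itself
# performs no such check.
import torch

SUPPORTED_ARCHS = {(7, 5), (8, 0), (8, 9), (9, 0)}

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) not in SUPPORTED_ARCHS:
        # Cards sharing a major version (e.g. sm_86) can usually run the lower
        # sm_80 binary, but no exact-match cubin will be present in the wheel.
        print(f"warning: sm_{major}{minor} is not in the explicit -gencode list")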
sgl-kernel/src/sgl-kernel/__init__.py

-from .ops import moe_align_block_size
+from sgl_kernel.ops import (
+    custom_dispose,
+    custom_reduce,
+    init_custom_reduce,
+    moe_align_block_size,
+    warp_reduce,
+)
+
+__all__ = [
+    "moe_align_block_size",
+    "warp_reduce",
+    "init_custom_reduce",
+    "custom_dispose",
+    "custom_reduce",
+]
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
...
@@ -3,11 +3,11 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>

#include "utils.hpp"

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
...
@@ -133,7 +133,3 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
        token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
  });
}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
-}
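The Python-visible binding for this kernel now lives in sgl_kernel_ops.cu (next file); moe_align_block_size buckets token-to-expert assignments into block_size-aligned groups and writes its results into caller-provided buffers. A hedged usage sketch of the public wrapper follows; the buffer shapes are an assumption based on the common vLLM-style convention and are not specified anywhere in this diff.

# Hedged sketch of calling the public wrapper. The argument order follows the
# C++ declaration in sgl_kernel_ops.cu; buffer sizes are ASSUMED (vLLM-style
# convention), not taken from this commit.
import torch
from sgl_kernel import moe_align_block_size

num_experts, block_size = 8, 16
topk_ids = torch.randint(0, num_experts, (16, 2), dtype=torch.int32, device="cuda")

max_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device="cuda")
experts_ids = torch.empty((max_padded + block_size - 1) // block_size, dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")
token_cnts_buffer = torch.zeros((num_experts + 1) * num_experts, dtype=torch.int32, device="cuda")
cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
                     experts_ids, num_tokens_post_pad, token_cnts_buffer, cumsum_buffer)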
sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc → sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (renamed)

#include <torch/extension.h>

#include "utils.hpp"

// warp_reduce
torch::Tensor warp_reduce_cuda(torch::Tensor input);

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_CUDA_INPUT(input);
  return warp_reduce_cuda(input);
}

// trt_reduce
using fptr_t = int64_t;
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

// moe_align_block_size
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer,
                          torch::Tensor cumsum_buffer);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // warp_reduce
  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  m.def("dispose", &dispose, "dispose custom allreduce meta");
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
  // moe_align_block_size
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
}
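With this rename, sgl_kernel_ops.cu becomes the single pybind11 entry point for the extension, built as the sgl_kernel.ops._kernels module defined in setup.py. A quick, illustrative check (not part of the repo) that a built wheel exposes every bound op:

# Illustrative check that the consolidated extension exposes all five bound ops.
# Module and symbol names come from setup.py and the PYBIND11_MODULE block above.
from sgl_kernel.ops import _kernels

for name in ("reduce", "init_custom_ar", "dispose", "all_reduce", "moe_align_block_size"):
    assert hasattr(_kernels, name), f"missing binding: {name}"
print("all sgl-kernel bindings present")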
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc (deleted, 100644 → 0)

#include <torch/extension.h>

#include "utils.hpp"

torch::Tensor warp_reduce_cuda(torch::Tensor input);

torch::Tensor warp_reduce(torch::Tensor input) {
  CHECK_CUDA_INPUT(input);
  return warp_reduce_cuda(input);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)");
}
sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "utils.hpp"

#define FINAL_MASK 0xffffffff
#define BLOCK_SIZE 256
...
sgl-kernel/src/sgl-kernel/ops/__init__.py

-from .moe_align_block_size import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
+from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import reduce as _reduce
+
+
+def warp_reduce(input_tensor):
+    return _reduce(input_tensor)
+
+
+def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
+    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)
+
+
+def custom_dispose(fa):
+    _dispose(fa)
+
+
+def custom_reduce(fa, inp, out):
+    _all_reduce(fa, inp, out)


def moe_align_block_size(
...
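ops/__init__.py is now a thin wrapper layer over the single _kernels module. A minimal usage sketch of the warp_reduce wrapper, assuming a CUDA float tensor as input (dtype and shape requirements are not documented in this diff):

# Minimal sketch; the wrapper simply forwards to the bound `reduce` op.
# Input dtype/shape expectations are assumptions, not stated in this commit.
import torch
from sgl_kernel import warp_reduce

x = torch.randn(1024, device="cuda", dtype=torch.float32)
y = warp_reduce(x)
print(y.shape)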