Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6619f48e
"sgl-kernel/python/sgl_kernel/__init__.py" did not exist on "b02da24a5b8cc0b8e4971f59a7e0f8afcfeab9b3"
Unverified
Commit
6619f48e
authored
Jan 24, 2025
by
Ke Bao
Committed by
GitHub
Jan 24, 2025
Browse files
Fix cu118 group gemm compile issue (#3097)
parent
3ed0a547
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
17 deletions
+21
-17
sgl-kernel/setup.py
sgl-kernel/setup.py
+21
-17
No files found.
sgl-kernel/setup.py
View file @
6619f48e
...
@@ -62,6 +62,23 @@ nvcc_flags = [
...
@@ -62,6 +62,23 @@ nvcc_flags = [
"-DFLASHINFER_ENABLE_F16"
,
"-DFLASHINFER_ENABLE_F16"
,
]
]
sources
=
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu"
,
"src/sgl-kernel/csrc/trt_reduce_kernel.cu"
,
"src/sgl-kernel/csrc/moe_align_kernel.cu"
,
"src/sgl-kernel/csrc/int8_gemm_kernel.cu"
,
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu"
,
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu"
,
"src/sgl-kernel/csrc/sgl_kernel_ops.cu"
,
"src/sgl-kernel/csrc/rotary_embedding.cu"
,
"3rdparty/flashinfer/csrc/activation.cu"
,
"3rdparty/flashinfer/csrc/bmm_fp8.cu"
,
"3rdparty/flashinfer/csrc/group_gemm.cu"
,
"3rdparty/flashinfer/csrc/norm.cu"
,
"3rdparty/flashinfer/csrc/sampling.cu"
,
"3rdparty/flashinfer/csrc/renorm.cu"
,
]
enable_bf16
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_BF16"
,
"0"
)
==
"1"
enable_bf16
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_BF16"
,
"0"
)
==
"1"
enable_fp8
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_FP8"
,
"0"
)
==
"1"
enable_fp8
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_FP8"
,
"0"
)
==
"1"
enable_sm90a
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_SM90A"
,
"0"
)
==
"1"
enable_sm90a
=
os
.
getenv
(
"SGL_KERNEL_ENABLE_SM90A"
,
"0"
)
==
"1"
...
@@ -71,6 +88,7 @@ sm_version = _get_device_sm()
...
@@ -71,6 +88,7 @@ sm_version = _get_device_sm()
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
if
cuda_version
>=
(
12
,
0
)
and
sm_version
>=
90
:
if
cuda_version
>=
(
12
,
0
)
and
sm_version
>=
90
:
nvcc_flags
.
append
(
"-gencode=arch=compute_90a,code=sm_90a"
)
nvcc_flags
.
append
(
"-gencode=arch=compute_90a,code=sm_90a"
)
sources
.
append
(
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu"
)
if
sm_version
>=
90
:
if
sm_version
>=
90
:
nvcc_flags
.
extend
(
nvcc_flags
.
extend
(
[
[
...
@@ -85,6 +103,7 @@ else:
...
@@ -85,6 +103,7 @@ else:
# compilation environment without GPU
# compilation environment without GPU
if
enable_sm90a
:
if
enable_sm90a
:
nvcc_flags
.
append
(
"-gencode=arch=compute_90a,code=sm_90a"
)
nvcc_flags
.
append
(
"-gencode=arch=compute_90a,code=sm_90a"
)
sources
.
append
(
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu"
)
if
enable_fp8
:
if
enable_fp8
:
nvcc_flags
.
extend
(
nvcc_flags
.
extend
(
[
[
...
@@ -110,26 +129,11 @@ for flag in [
...
@@ -110,26 +129,11 @@ for flag in [
cxx_flags
=
[
"-O3"
]
cxx_flags
=
[
"-O3"
]
libraries
=
[
"c10"
,
"torch"
,
"torch_python"
,
"cuda"
]
libraries
=
[
"c10"
,
"torch"
,
"torch_python"
,
"cuda"
]
extra_link_args
=
[
"-Wl,-rpath,$ORIGIN/../../torch/lib"
,
"-L/usr/lib/x86_64-linux-gnu"
]
extra_link_args
=
[
"-Wl,-rpath,$ORIGIN/../../torch/lib"
,
"-L/usr/lib/x86_64-linux-gnu"
]
ext_modules
=
[
ext_modules
=
[
CUDAExtension
(
CUDAExtension
(
name
=
"sgl_kernel.ops._kernels"
,
name
=
"sgl_kernel.ops._kernels"
,
sources
=
[
sources
=
sources
,
"src/sgl-kernel/csrc/trt_reduce_internal.cu"
,
"src/sgl-kernel/csrc/trt_reduce_kernel.cu"
,
"src/sgl-kernel/csrc/moe_align_kernel.cu"
,
"src/sgl-kernel/csrc/int8_gemm_kernel.cu"
,
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu"
,
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu"
,
"src/sgl-kernel/csrc/sgl_kernel_ops.cu"
,
"src/sgl-kernel/csrc/rotary_embedding.cu"
,
"3rdparty/flashinfer/csrc/activation.cu"
,
"3rdparty/flashinfer/csrc/bmm_fp8.cu"
,
"3rdparty/flashinfer/csrc/group_gemm.cu"
,
"3rdparty/flashinfer/csrc/group_gemm_sm90.cu"
,
"3rdparty/flashinfer/csrc/norm.cu"
,
"3rdparty/flashinfer/csrc/sampling.cu"
,
"3rdparty/flashinfer/csrc/renorm.cu"
,
],
include_dirs
=
include_dirs
,
include_dirs
=
include_dirs
,
extra_compile_args
=
{
extra_compile_args
=
{
"nvcc"
:
nvcc_flags
,
"nvcc"
:
nvcc_flags
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment