Commit 07a22cbb (unverified), authored Jan 23, 2025 by Yineng Zhang; committed by GitHub, Jan 23, 2025.
use env variable to control the build conf on the CPU build node (#3080)
Parent: 3d0bfa3e

Showing 2 changed files with 49 additions and 21 deletions.
sgl-kernel/build.sh  (+3, -0)
sgl-kernel/setup.py  (+46, -21)
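The new switches can also be exercised outside of CI. The following is a hypothetical sketch (not part of this commit) that drives the same wheel build from a CPU-only machine by passing the SGL_KERNEL_ENABLE_* variables through the environment; it assumes it is run from the repository root with torch and a CUDA toolkit installed (the stub libcuda.so is enough).

# Hypothetical usage sketch (not part of this commit): drive the wheel build on a
# CPU-only node by setting the new SGL_KERNEL_ENABLE_* switches, mirroring what
# sgl-kernel/build.sh now does inside the build container.
import os
import subprocess

env = dict(
    os.environ,
    SGL_KERNEL_ENABLE_BF16="1",   # compile BF16 kernels even though no GPU is visible
    SGL_KERNEL_ENABLE_FP8="1",    # compile FP8 kernels
    SGL_KERNEL_ENABLE_SM90A="1",  # emit sm_90a code for Hopper
)

# Assumes the sgl-kernel directory layout from the repository root.
subprocess.run(
    ["python", "setup.py", "bdist_wheel"],
    cwd="sgl-kernel",
    env=env,
    check=True,
)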
sgl-kernel/build.sh  (view file @ 07a22cbb)
...
@@ -11,6 +11,9 @@ docker run --rm \
     ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
     export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
     export CUDA_VERSION=${CUDA_VERSION} && \
+    export SGL_KERNEL_ENABLE_BF16=1 && \
+    export SGL_KERNEL_ENABLE_FP8=1 && \
+    export SGL_KERNEL_ENABLE_SM90A=1 && \
     mkdir -p /usr/lib/x86_64-linux-gnu/ && \
     ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
     cd /sgl-kernel && \
...
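Because these variables are exported inside the build container, the setup.py shown next can read them when no GPU is visible. A quick, hypothetical sanity check (not part of the commit) from within that environment:

# Hypothetical check: confirm the switches exported by build.sh are visible
# to the Python build process.
import os

for name in ("SGL_KERNEL_ENABLE_BF16", "SGL_KERNEL_ENABLE_FP8", "SGL_KERNEL_ENABLE_SM90A"):
    print(name, "=", os.getenv(name, "0"))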
sgl-kernel/setup.py  (view file @ 07a22cbb)
 import os
 from pathlib import Path

 import torch
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 from version import __version__

 root = Path(__file__).parent.resolve()


-def update_wheel_platform_tag():
+def _update_wheel_platform_tag():
     wheel_dir = Path("dist")
     if wheel_dir.exists() and wheel_dir.is_dir():
         old_wheel = next(wheel_dir.glob("*.whl"))
...
@@ -18,21 +18,25 @@ def update_wheel_platform_tag():
         old_wheel.rename(new_wheel)


-def get_cuda_version():
+def _get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
     return (0, 0)


-def get_device_sm():
+def _get_device_sm():
     if torch.cuda.is_available():
         major, minor = torch.cuda.get_device_capability()
         return major * 10 + minor
     return 0


-cuda_version = get_cuda_version()
-sm_version = get_device_sm()
+def _get_version():
+    with open(root / "pyproject.toml") as f:
+        for line in f:
+            if line.startswith("version"):
+                return line.split("=")[1].strip().strip('"')


 cutlass = root / "3rdparty" / "cutlass"
 flashinfer = root / "3rdparty" / "flashinfer"
...
@@ -58,19 +62,39 @@ nvcc_flags = [
     "-DFLASHINFER_ENABLE_F16",
 ]

-if cuda_version >= (12, 0) and sm_version >= 90:
-    nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
-if sm_version >= 90:
-    nvcc_flags.extend(
-        [
-            "-DFLASHINFER_ENABLE_FP8",
-            "-DFLASHINFER_ENABLE_FP8_E4M3",
-            "-DFLASHINFER_ENABLE_FP8_E5M2",
-        ]
-    )
-if sm_version >= 80:
-    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
+enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
+enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
+cuda_version = _get_cuda_version()
+sm_version = _get_device_sm()
+
+if torch.cuda.is_available():
+    if cuda_version >= (12, 0) and sm_version >= 90:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if sm_version >= 90:
+        nvcc_flags.extend(
+            [
+                "-DFLASHINFER_ENABLE_FP8",
+                "-DFLASHINFER_ENABLE_FP8_E4M3",
+                "-DFLASHINFER_ENABLE_FP8_E5M2",
+            ]
+        )
+    if sm_version >= 80:
+        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+else:
+    # compilation environment without GPU
+    if enable_sm90a:
+        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if enable_fp8:
+        nvcc_flags.extend(
+            [
+                "-DFLASHINFER_ENABLE_FP8",
+                "-DFLASHINFER_ENABLE_FP8_E4M3",
+                "-DFLASHINFER_ENABLE_FP8_E5M2",
+            ]
+        )
+    if enable_bf16:
+        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")

 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
...
@@ -82,6 +106,7 @@ for flag in [
         torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
     except ValueError:
         pass

 cxx_flags = ["-O3"]
 libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
...
@@ -116,11 +141,11 @@ ext_modules = [

 setup(
     name="sgl-kernel",
-    version=__version__,
+    version=_get_version(),
     packages=find_packages(),
     package_dir={"": "src"},
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
 )

-update_wheel_platform_tag()
+_update_wheel_platform_tag()
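Taken together, the new behavior on a node without a GPU is that each SGL_KERNEL_ENABLE_* variable independently switches on one group of nvcc flags, while a node with a GPU keeps the old capability-based detection. A condensed, hypothetical restatement of the no-GPU branch (the function name cpu_node_nvcc_flags is illustrative, not from the commit):

# Minimal sketch (not from the commit) of the flag-selection rule setup.py now
# applies when torch.cuda.is_available() is False: each SGL_KERNEL_ENABLE_*
# variable independently enables one group of nvcc flags.
import os

def cpu_node_nvcc_flags(env=os.environ):
    flags = []
    if env.get("SGL_KERNEL_ENABLE_SM90A", "0") == "1":
        flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if env.get("SGL_KERNEL_ENABLE_FP8", "0") == "1":
        flags += [
            "-DFLASHINFER_ENABLE_FP8",
            "-DFLASHINFER_ENABLE_FP8_E4M3",
            "-DFLASHINFER_ENABLE_FP8_E5M2",
        ]
    if env.get("SGL_KERNEL_ENABLE_BF16", "0") == "1":
        flags.append("-DFLASHINFER_ENABLE_BF16")
    return flags

# With all three variables set to "1", as build.sh now exports, the CPU build
# node selects the same sm_90a/FP8/BF16 flags a Hopper (SM90) GPU build would.
print(cpu_node_nvcc_flags({
    "SGL_KERNEL_ENABLE_BF16": "1",
    "SGL_KERNEL_ENABLE_FP8": "1",
    "SGL_KERNEL_ENABLE_SM90A": "1",
}))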