sglang · Commit 23010630 (Unverified)

chore: upgrade sgl-kernel v0.1.2.post1 (#6196)

Authored May 11, 2025 by Yineng Zhang; committed by GitHub on May 11, 2025.
Co-authored-by: alcanderian <alcanderian@gmail.com>
Parent: 45b4dcf0
Showing 5 changed files with 61 additions and 71 deletions (+61, -71):
  python/pyproject.toml                                +1   -1
  python/sglang/srt/entrypoints/engine.py              +1   -1
  python/sglang/srt/layers/quantization/deep_gemm.py   +57  -67
  scripts/ci_install_dependency.sh                     +1   -1
  scripts/ci_install_dependency_8_gpu.sh               +1   -1
python/pyproject.toml

@@ -48,7 +48,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.1",
+    "sgl-kernel==0.1.2.post1",
     "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
python/sglang/srt/entrypoints/engine.py

@@ -486,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.1",
+            "0.1.2.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
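Note: the hunk above only raises the minimum sgl-kernel version asserted at engine startup. As a rough illustration only (not sglang's actual assert_pkg_version, which lives in sglang.srt.utils and may behave differently), such a version gate can be sketched with importlib.metadata and packaging:

# Illustrative sketch: a hypothetical check_pkg_version helper mirroring the
# behavior suggested by the assert_pkg_version call above.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_pkg_version(pkg: str, min_version: str, hint: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than the required {min_version}. {hint}"
        )


check_pkg_version(
    "sgl-kernel",
    "0.1.2.post1",
    "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)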
python/sglang/srt/layers/quantization/deep_gemm.py
@@ -16,11 +16,7 @@ if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

     sm_version = get_device_sm()
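These import changes track DeepGEMM's newer JIT interface: the raw includes/template source fragments are dropped in favor of the FP8GemmRuntime class and GemmType enum from deep_gemm.jit_kernels.runtime. A small, hedged compatibility probe (not part of this commit) could detect which API an installed DeepGEMM build exposes:

# Hedged sketch: probe whether the installed DeepGEMM build exposes the
# runtime-based JIT API that the code above now imports.
try:
    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType  # newer API
    HAS_RUNTIME_JIT_API = True
except ImportError:
    FP8GemmRuntime = None
    GemmType = None
    HAS_RUNTIME_JIT_API = False

print(f"DeepGEMM runtime JIT API available: {HAS_RUNTIME_JIT_API}")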
@@ -45,10 +41,15 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss with some cases.
+# And NVCC JIT speed is also 9x faster in the ref commit
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
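Operationally, the cache redirect now sets DG_JIT_CACHE_DIR (instead of DG_CACHE_DIR), and NVRTC stays off unless SGL_DG_USE_NVRTC is set. A short usage sketch of how the defaults resolve; the /tmp path below is only an example:

import os

# Optional override; otherwise the default below is used.
os.environ.setdefault("SGL_DG_CACHE_DIR", "/tmp/deep_gemm_cache")  # example path

# Mirrors the resolution done in deep_gemm.py after this commit.
cache_dir = os.getenv(
    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)
use_nvrtc = os.getenv("SGL_DG_USE_NVRTC", "0")  # "0" = NVCC JIT (default), "1" = NVRTC

print(f"DG_JIT_CACHE_DIR would be set to: {cache_dir}")
print(f"DG_JIT_USE_NVRTC would be set to: {use_nvrtc}")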
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
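The rewritten call no longer hands the tuner explicit argument definitions and template sources; it precomputes the compile-time constants and passes them as kwargs alongside runtime_cls=FP8GemmRuntime. A hedged helper (not part of sglang) showing how those kwargs derive from a get_best_configs-style tuple:

from typing import Dict, Tuple

Config = Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]


def build_fp8_gemm_kwargs(config: Config) -> Dict[str, int]:
    # Hypothetical helper: collect the values the new compile_and_tune call
    # passes as compile-time kwargs (thread counts and BLOCK_K are fixed at 128).
    num_sms, _block_m, _block_n, _num_stages, _tma_cfg, smem_config = config
    return {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_config[0],  # first element of smem_config is the shared-memory size
    }


# Example with a made-up config tuple:
print(build_fp8_gemm_kwargs((132, 64, 128, 5, (2, True), (199776, 128, 128))))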
@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)

@@ -385,6 +375,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
             )
         return ret

-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
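The logging shim in _log_jit_build keeps the same wrap-then-restore pattern but now targets RuntimeCache.get, the lookup method used by newer DeepGEMM builds, instead of __getitem__. A self-contained sketch of that pattern with a stand-in cache class (DeepGEMM itself is not required to run it):

import logging
from contextlib import contextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _DemoCache:
    """Stand-in for DeepGEMM's RuntimeCache, used only for this sketch."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        return self._store.get(key)


@contextmanager
def log_cache_miss(cache_cls=_DemoCache):
    # Same pattern as _log_jit_build: temporarily wrap the cache lookup so a
    # None result (cache miss, i.e. a JIT build) can be logged, then restore it.
    origin_func = cache_cls.get

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            logger.info("Cache miss: a JIT build would be triggered here")
        return ret

    cache_cls.get = __patched_func
    try:
        yield
    finally:
        cache_cls.get = origin_func


if __name__ == "__main__":
    with log_cache_miss():
        _DemoCache().get("gemm_fp8_fp8_bf16_nt")  # logs a miss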
scripts/ci_install_dependency.sh

@@ -16,7 +16,7 @@ rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"
scripts/ci_install_dependency_8_gpu.sh

@@ -34,7 +34,7 @@ rm -rf /usr/local/include/nvshmem*
 pip install --upgrade pip

 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir

 # Install the main package
 pip install -e "python[all]"