Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
eebfdb94
Unverified
Commit
eebfdb94
authored
Apr 27, 2025
by
JieXin Liang
Committed by
GitHub
Apr 26, 2025
Browse files
[fix] fix potential bumpy throughput with deepgemm (#5722)
parent
dfb32264
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
10 deletions
+17
-10
python/sglang/compile_deep_gemm.py
python/sglang/compile_deep_gemm.py
+1
-1
python/sglang/srt/layers/quantization/deep_gemm.py
python/sglang/srt/layers/quantization/deep_gemm.py
+16
-9
No files found.
python/sglang/compile_deep_gemm.py
View file @
eebfdb94
...
@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup
...
@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup
multiprocessing
.
set_start_method
(
"spawn"
,
force
=
True
)
multiprocessing
.
set_start_method
(
"spawn"
,
force
=
True
)
# Reduce warning
# Reduce warning
os
.
environ
[
"SGL_IN_DEEP
_
GEMM_PRE
_
COMPILE_STAGE"
]
=
"1"
os
.
environ
[
"SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"
]
=
"1"
# Force enable deep gemm
# Force enable deep gemm
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"1"
os
.
environ
[
"SGL_ENABLE_JIT_DEEPGEMM"
]
=
"1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
...
...
python/sglang/srt/layers/quantization/deep_gemm.py
View file @
eebfdb94
...
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
...
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
=
get_bool_env_var
(
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
=
get_bool_env_var
(
"SGL_JIT_DEEPGEMM_PRECOMPILE"
,
"true"
"SGL_JIT_DEEPGEMM_PRECOMPILE"
,
"true"
)
)
_DO_COMPILE
=
get_bool_env_var
(
"SGL_IS_FIRST_RANK_ON_NODE"
,
"true"
)
_DO_COMPILE_ALL
=
True
_IS_FIRST_RANK_ON_NODE
=
get_bool_env_var
(
"SGL_IS_FIRST_RANK_ON_NODE"
,
"true"
)
_COMPILE_WORKERS
=
get_int_env_var
(
"SGL_JIT_DEEPGEMM_COMPILE_WORKERS"
,
4
)
_COMPILE_WORKERS
=
get_int_env_var
(
"SGL_JIT_DEEPGEMM_COMPILE_WORKERS"
,
4
)
_IN_PRE
_
COMPILE_STAGE
=
get_bool_env_var
(
"SGL_IN_DEEP
_
GEMM_PRE
_
COMPILE_STAGE"
,
"false"
)
_IN_PRECOMPILE_STAGE
=
get_bool_env_var
(
"SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"
,
"false"
)
# Force redirect deep_gemm cache_dir
# Force redirect deep_gemm cache_dir
os
.
environ
[
"DG_CACHE_DIR"
]
=
os
.
getenv
(
os
.
environ
[
"DG_CACHE_DIR"
]
=
os
.
getenv
(
...
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
...
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
def
update_deep_gemm_config
(
gpu_id
:
int
,
server_args
:
ServerArgs
):
def
update_deep_gemm_config
(
gpu_id
:
int
,
server_args
:
ServerArgs
):
global
_BUILTIN_M_LIST
global
_BUILTIN_M_LIST
global
_DO_COMPILE
global
_DO_COMPILE_ALL
global
_IS_FIRST_RANK_ON_NODE
# Generate m_max
# Generate m_max
m_max
=
1024
*
16
m_max
=
1024
*
16
...
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
...
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
m_max
=
min
(
1024
*
128
,
m_max
)
m_max
=
min
(
1024
*
128
,
m_max
)
_BUILTIN_M_LIST
=
list
(
range
(
1
,
m_max
+
1
))
_BUILTIN_M_LIST
=
list
(
range
(
1
,
m_max
+
1
))
# Check if is the first rank on node
_IS_FIRST_RANK_ON_NODE
=
ServerArgs
.
base_gpu_id
==
gpu_id
_DO_COMPILE
=
ServerArgs
.
base_gpu_id
==
gpu_id
# Check whether this is the first rank on the node.
# By default, each rank tries to compile all Ms so that
# all symbols are loaded at the launch stage,
# which avoids loading symbols at the serving stage.
_DO_COMPILE_ALL
=
_IS_FIRST_RANK_ON_NODE
or
not
_IN_PRECOMPILE_STAGE
class
DeepGemmKernelType
(
IntEnum
):
class
DeepGemmKernelType
(
IntEnum
):
...
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
...
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
def
_compile_warning_1
():
def
_compile_warning_1
():
if
not
_IN_PRE
_
COMPILE_STAGE
:
if
not
_IN_PRECOMPILE_STAGE
and
_IS_FIRST_RANK_ON_NODE
:
logger
.
warning
(
logger
.
warning
(
"Entering DeepGEMM JIT Pre-Complie session. "
"Entering DeepGEMM JIT Pre-Complie session. "
"And it may takes a long time(Typically 10-20 mins) "
"And it may takes a long time(Typically 10-20 mins) "
...
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
...
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
query_key
=
(
kernel_type
,
n
,
k
,
num_groups
)
query_key
=
(
kernel_type
,
n
,
k
,
num_groups
)
if
(
if
(
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
and
_DO_COMPILE
and
_DO_COMPILE
_ALL
and
_INITIALIZATION_DICT
.
get
(
query_key
)
is
None
and
_INITIALIZATION_DICT
.
get
(
query_key
)
is
None
):
):
_INITIALIZATION_DICT
[
query_key
]
=
True
_INITIALIZATION_DICT
[
query_key
]
=
True
...
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
...
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
logger
.
info
(
logger
.
info
(
f
"Try DeepGEMM JIT Compiling for "
f
"Try DeepGEMM JIT Compiling for "
f
"<
{
kernel_helper
.
name
}
> N=
{
n
}
, K=
{
k
}
, num_groups=
{
num_groups
}
with all Ms."
f
"<
{
kernel_helper
.
name
}
> N=
{
n
}
, K=
{
k
}
, num_groups=
{
num_groups
}
with all Ms."
f
"
{
' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. '
if
not
_IN_PRE
_
COMPILE_STAGE
else
''
}
"
f
"
{
' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. '
if
not
_IN_PRECOMPILE_STAGE
else
''
}
"
)
)
# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
...
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
...
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
@
contextmanager
@
contextmanager
def
_log_jit_build
(
M
:
int
,
N
:
int
,
K
:
int
,
kernel_type
:
DeepGemmKernelType
):
def
_log_jit_build
(
M
:
int
,
N
:
int
,
K
:
int
,
kernel_type
:
DeepGemmKernelType
):
if
_IN_PRE
_
COMPILE_STAGE
:
if
_IN_PRECOMPILE_STAGE
:
yield
yield
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment