change / sglang · Commits

Commit c2942907 (unverified)
[feature] enable pre compile jit deep_gemm (#5580)

Authored Apr 22, 2025 by JieXin Liang; committed by GitHub, Apr 21, 2025
Parent: e69a2190

Showing 7 changed files with 549 additions and 45 deletions.
Changed files:
  python/sglang/compile_deep_gemm.py                    +136   -0
  python/sglang/srt/layers/quantization/deep_gemm.py    +378   -0
  python/sglang/srt/layers/quantization/fp8_kernel.py     +7  -38
  python/sglang/srt/layers/quantization/fp8_utils.py      +2   -2
  python/sglang/srt/model_executor/model_runner.py        +8   -0
  python/sglang/srt/models/deepseek_v2.py                 +8   -5
  python/sglang/srt/utils.py                             +10   -0
python/sglang/compile_deep_gemm.py  (new file · 0 → 100644)

"""
Compile DeepGEMM kernels for a model with the specified server arguments.

This script launches a server to capture DeepGEMM calls and then compiles the kernels.
It accepts the same server arguments as launch_server.py.

Usage:
    python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
"""

import argparse
import dataclasses
import multiprocessing
import os
import time

import requests

from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.srt.warmup import warmup

multiprocessing.set_start_method("spawn", force=True)

# Reduce warnings
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"


@dataclasses.dataclass
class CompileArgs:
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Use the default value's type to cast the args into the correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )


@warmup("compile-deep-gemm")
async def warm_up_compile(tokenizer_manager: TokenizerManager):
    print("\nGenerating a warm-up request for compiling DeepGEMM...\n")
    generate_req_input = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": True,
        },
    )
    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()


def launch_server_internal(server_args):
    try:
        launch_server(server_args)
    except Exception as e:
        raise e
    finally:
        kill_process_tree(os.getpid(), include_parent=False)


def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout

    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc
        except requests.RequestException:
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM kernel compilation timed out."
        "\n\nFeel free to restart the command."
    )


def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
    # Disable CUDA graph and torch compile to save time
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    print("Disable CUDA Graph and Torch Compile to save time...")

    # Set the watchdog timeout to compile_args.timeout because compilation takes a long time
    server_args.watchdog_timeout = compile_args.timeout
    server_args.warmups = "compile-deep-gemm"


def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    print(
        "Begin DeepGEMM kernel compilation...\n"
        "It may take a long time, and a timeout may be raised "
        "while the compilation is still in progress.\n"
        "Feel free to restart the command "
        "until the compilation is fully finished.\n"
    )

    proc = launch_server_process_and_send_one_request(server_args, compile_args)

    kill_process_tree(proc.pid)

    print("\nDeepGEMM kernel compilation finished successfully.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    CompileArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    compile_args = CompileArgs.from_cli_args(args)

    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)
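
A note on the CLI plumbing above: `CompileArgs.from_cli_args` casts each parsed value using the type of the corresponding dataclass field's default. A minimal, self-contained sketch of the same pattern, using a toy `DemoArgs` dataclass that is not part of this commit:

import argparse
import dataclasses


@dataclasses.dataclass
class DemoArgs:
    # Hypothetical fields for illustration; only `timeout` mirrors CompileArgs.
    timeout: int = 3600
    poll_interval: float = 10.0

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Use each default value's type to cast the parsed (string) argument.
        attrs = [(f.name, type(f.default)) for f in dataclasses.fields(cls)]
        return cls(**{name: cast(getattr(args, name)) for name, cast in attrs})


parser = argparse.ArgumentParser()
parser.add_argument("--timeout", default=DemoArgs.timeout)
parser.add_argument("--poll-interval", dest="poll_interval", default=DemoArgs.poll_interval)
print(DemoArgs.from_cli_args(parser.parse_args(["--timeout", "600"])))
# -> DemoArgs(timeout=600, poll_interval=10.0)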
python/sglang/srt/layers/quantization/deep_gemm.py  (new file · 0 → 100644)

import logging
import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple

import torch
from tqdm.contrib.concurrent import thread_map

from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda

_ENABLE_JIT_DEEPGEMM = False
if is_cuda():
    import deep_gemm
    from deep_gemm import get_num_sms
    from deep_gemm.jit_kernels.gemm import get_best_configs
    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
    from deep_gemm.jit_kernels.m_grouped_gemm import (
        template as deep_gemm_grouped_gemm_template,
    )
    from deep_gemm.jit_kernels.tuner import jit_tuner

    sm_version = get_device_sm()
    if sm_version == 90:
        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
            _ENABLE_JIT_DEEPGEMM = True

logger = logging.getLogger(__name__)

_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var("SGL_JIT_DEEPGEMM_PRECOMPILE", "true")
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")

# Force redirect deep_gemm cache_dir
os.environ["DG_CACHE_DIR"] = os.getenv(
    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
)


def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    global _BUILTIN_M_LIST
    global _DO_COMPILE

    # Generate m_max
    m_max = 1024 * 16
    if server_args.chunked_prefill_size < 1:
        m_max = 1024 * 64
    elif server_args.chunked_prefill_size > 8192:
        m_max = server_args.chunked_prefill_size * 2
    m_max = min(1024 * 128, m_max)
    _BUILTIN_M_LIST = list(range(1, m_max + 1))

    # Check whether this is the first rank on the node
    _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id


class DeepGemmKernelType(IntEnum):
    GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
    GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
    GEMM_NT_F8F8BF16 = auto()


@dataclass
class DeepGemmKernelHelper:
    name: str
    compile_func: Callable[
        [
            int,
            int,
            int,
            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
        ],
        None,
    ]
    configure_func: Callable[
        [int, int, int, int, int],
        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
    ]


_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()


def _compile_warning_1():
    if not _IN_PRE_COMPILE_STAGE:
        logger.warning(
            "Entering DeepGEMM JIT pre-compile session. "
            "It may take a long time (typically 10-20 minutes) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
            " beforehand to reduce this overhead. "
            "For example: "
            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
        )


def _compile_warning_2():
    logger.warning(
        "Entering DeepGEMM JIT single-kernel compile session, "
        "which makes inference throughput flaky. "
        "Please run `sglang.compile_deep_gemm` with the same args as `sglang.launch_server`"
        " for pre-compilation to avoid this. "
        "For example: "
        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
    )


def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
    n: int,
    k: int,
    num_groups: int,
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    # Auto-tuning with compilation
    global deep_gemm_includes, deep_gemm_grouped_gemm_template
    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
    _ = jit_tuner.compile_and_tune(
        name="m_grouped_gemm_fp8_fp8_bf16_nt",
        keys={
            "N": n,
            "K": k,
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "SWIZZLE_D_MODE": smem_config[1],
            "BLOCK_N_PADDING": smem_config[2],
            "NUM_GROUPS": num_groups,
            "NUM_STAGES": num_stages,
            "NUM_TMA_MULTICAST": tma_multicast_config[0],
            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
            "GEMM_TYPE": "GroupedMasked",
        },
        space=(),
        includes=deep_gemm_includes,
        arg_defs=(
            ("lhs", torch.float8_e4m3fn),
            ("lhs_scales", torch.float),
            ("rhs", torch.float8_e4m3fn),
            ("rhs_scales", torch.float),
            ("out", torch.bfloat16),
            ("grouped_layout", torch.int32),
            ("m", int),
            ("stream", torch.cuda.Stream),
            ("num_sms", int),
            ("smem_size", int),
        ),
        template=deep_gemm_grouped_gemm_template,
        args=[],
    )


def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
    n: int,
    k: int,
    num_groups: int,
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    global deep_gemm_includes, deep_gemm_grouped_gemm_template
    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
    _ = jit_tuner.compile_and_tune(
        name="m_grouped_gemm_fp8_fp8_bf16_nt",
        keys={
            "N": n,
            "K": k,
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "SWIZZLE_D_MODE": smem_config[1],
            "BLOCK_N_PADDING": smem_config[2],
            "NUM_GROUPS": num_groups,
            "NUM_STAGES": num_stages,
            "NUM_TMA_MULTICAST": tma_multicast_config[0],
            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
            "GEMM_TYPE": "GroupedContiguous",
        },
        space=(),
        includes=deep_gemm_includes,
        arg_defs=(
            ("lhs", torch.float8_e4m3fn),
            ("lhs_scales", torch.float),
            ("rhs", torch.float8_e4m3fn),
            ("rhs_scales", torch.float),
            ("out", torch.bfloat16),
            ("grouped_layout", torch.int32),
            ("m", int),
            ("num_groups", int),
            ("stream", torch.cuda.Stream),
            ("num_sms", int),
            ("smem_size", int),
        ),
        template=deep_gemm_grouped_gemm_template,
        args=[],
    )


def _compile_gemm_nt_f8f8bf16_one(
    n: int,
    k: int,
    _: int,  # _ is a dummy parameter to align with other interfaces
    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
    global deep_gemm_includes, deep_gemm_gemm_template
    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
    _ = jit_tuner.compile_and_tune(
        name="gemm_fp8_fp8_bf16_nt",
        keys={
            "N": n,
            "K": k,
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "SWIZZLE_D_MODE": smem_config[1],
            "BLOCK_N_PADDING": smem_config[2],
            "NUM_STAGES": num_stages,
            "NUM_TMA_MULTICAST": tma_multicast_config[0],
            "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
        },
        space=(),
        includes=deep_gemm_includes,
        arg_defs=(
            ("lhs", torch.float8_e4m3fn),
            ("lhs_scales", torch.float),
            ("rhs", torch.float8_e4m3fn),
            ("rhs_scales", torch.float),
            ("out", torch.bfloat16),
            ("m", int),
            ("stream", torch.cuda.Stream),
            ("num_sms", int),
            ("smem_size", int),
        ),
        template=deep_gemm_gemm_template,
        args=[],
    )


_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
            m, n, k, num_groups, num_sms, is_grouped_masked=True
        ),
    ),
    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
            m, n, k, 1, num_sms, is_grouped_contiguous=True
        ),
    ),
    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
        name="gemm_fp8_fp8_bf16_nt",
        compile_func=_compile_gemm_nt_f8f8bf16_one,
        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
            m, n, k, 1, num_sms
        ),
    ),
}


def _maybe_compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
    m_list: Optional[List[int]] = None,
) -> None:
    global _INITIALIZATION_DICT
    global _BUILTIN_M_LIST

    query_key = (kernel_type, n, k, num_groups)
    if (
        _ENABLE_JIT_DEEPGEMM_PRECOMPILE
        and _DO_COMPILE
        and _INITIALIZATION_DICT.get(query_key) is None
    ):
        _INITIALIZATION_DICT[query_key] = True

        kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
        _compile_warning_1()
        logger.info(
            f"Trying DeepGEMM JIT compilation for "
            f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
            f"{' It only takes a little time (typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
        )

        # NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
        num_sms = get_num_sms()
        collected_configs = set()
        for m in m_list if m_list is not None else _BUILTIN_M_LIST:
            # Put configs into a set to deduplicate them and reduce the cases to be compiled
            collected_configs.add(
                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
            )
        compile_func = lambda config: kernel_helper.compile_func(
            n, k, num_groups, config
        )
        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)


def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
):
    num_groups, _, k = lhs[0].shape
    _, n, _ = rhs[0].shape
    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED

    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)

    with _log_jit_build(expected_m, n, k, kernel_type):
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            lhs, rhs, out, masked_m, expected_m
        )


def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    m, k = lhs[0].shape
    num_groups, n, _ = rhs[0].shape
    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)

    with _log_jit_build(m, n, k, kernel_type):
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)


def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    m, k = lhs[0].shape
    n, _ = rhs[0].shape
    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16

    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)

    with _log_jit_build(m, n, k, kernel_type):
        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)


@contextmanager
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
    if _IN_PRE_COMPILE_STAGE:
        yield
        return

    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.__getitem__

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
            _compile_warning_2()
            logger.warning(
                f"DeepGEMM JIT compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
            )
        return ret

    RuntimeCache.__getitem__ = __patched_func
    yield
    RuntimeCache.__getitem__ = origin_func
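
The precompile path in `_maybe_compile_deep_gemm_one_type_all` above enumerates every M in `_BUILTIN_M_LIST` but deduplicates the tuning configs before compiling, so only a handful of kernel variants are actually built, in parallel. A rough standalone sketch of that dedupe-then-parallel-compile shape, with placeholder `demo_configure`/`demo_compile` functions standing in for DeepGEMM's tuner:

from concurrent.futures import ThreadPoolExecutor


def demo_configure(m: int, n: int, k: int) -> tuple:
    # Placeholder heuristic: map a problem size to a (block_m, block_n) tile choice.
    return (64 if m <= 64 else 128, 128)


def demo_compile(config: tuple) -> None:
    print(f"compiling kernel variant for config {config}")


n, k = 4096, 7168
# Thousands of M values collapse into very few unique configs.
unique_configs = {demo_configure(m, n, k) for m in range(1, 16384 + 1)}
with ThreadPoolExecutor(max_workers=4) as pool:  # plays the role of _COMPILE_WORKERS
    list(pool.map(demo_compile, unique_configs))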
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -16,19 +16,17 @@ import functools
 import json
 import logging
 import os
 from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_device_core_count,
     get_device_name,
     get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,

@@ -43,22 +41,16 @@ else:
 fp8_max = torch.finfo(_fp8_type).max
 fp8_min = -fp8_max

-_enable_jit_deepgemm = False
-_enable_jit_deepgemm_bmm = False
 if _is_cuda:
-    import deep_gemm
     from sgl_kernel import (
         sgl_per_tensor_quant_fp8,
         sgl_per_token_group_quant_fp8,
         sgl_per_token_quant_fp8,
     )

-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
-            _enable_jit_deepgemm = True
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
-            _enable_jit_deepgemm_bmm = True
+    from sglang.srt.layers.quantization.deep_gemm import (
+        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+    )

 logger = logging.getLogger(__name__)

@@ -71,10 +63,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
         M, K = A.shape
         N, _ = B.shape
-        with _log_jit_build(M, N, K):
-            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)

     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,

@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
     return None


-@contextmanager
-def _log_jit_build(M: int, N: int, K: int):
-    from deep_gemm.jit.runtime import RuntimeCache
-
-    origin_func = RuntimeCache.__getitem__
-
-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            logger.warning(
-                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret
-
-    RuntimeCache.__getitem__ = __patched_func
-    yield
-    RuntimeCache.__getitem__ = origin_func
-
-
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,

@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
     )

     # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
-            with _log_jit_build(M, N, K):
-                deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
         kernel = (
             _w8a8_block_fp8_matmul_unrolledx4
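
The `_log_jit_build` helper removed from this file (and re-added in deep_gemm.py with an extra `kernel_type` argument) works by temporarily swapping the JIT runtime cache's `__getitem__` so a cache miss can be logged before the slow build starts. A stripped-down illustration of that monkey-patching context-manager pattern, using an invented `DemoCache` class instead of DeepGEMM's `RuntimeCache`:

from contextlib import contextmanager


class DemoCache:
    _store = {}

    def __getitem__(self, key):
        return self._store.get(key)  # None means "kernel not built yet"


@contextmanager
def log_cache_misses(tag: str):
    origin = DemoCache.__getitem__

    def patched(self, key):
        ret = origin(self, key)
        if ret is None:
            print(f"[{tag}] cache miss for {key}; JIT building, please wait...")
        return ret

    DemoCache.__getitem__ = patched
    try:
        yield
    finally:
        DemoCache.__getitem__ = origin


with log_cache_misses("gemm_fp8_fp8_bf16_nt"):
    DemoCache()[("N=4096", "K=7168")]  # triggers the cache-miss message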
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -12,8 +12,8 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,

@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        if _enable_jit_deepgemm:
+        if _ENABLE_JIT_DEEPGEMM:
             q_input, x_scale = sglang_per_token_group_quant_fp8(
                 input_2d,
                 block_size[1],
python/sglang/srt/model_executor/model_runner.py

@@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
+from sglang.srt.layers.quantization.deep_gemm import (
+    _ENABLE_JIT_DEEPGEMM,
+    update_deep_gemm_config,
+)
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager

@@ -169,6 +173,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

+        # Update deep gemm configure
+        if _ENABLE_JIT_DEEPGEMM:
+            update_deep_gemm_config(gpu_id, server_args)
+
         # If it is a draft model tp_group can be different.
         self.initialize(min_per_gpu_memory)
python/sglang/srt/models/deepseek_v2.py

@@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    _enable_jit_deepgemm_bmm,
     per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
 )

@@ -86,8 +86,11 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()

 if _is_cuda:
-    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+
+    from sglang.srt.layers.quantization.deep_gemm import (
+        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
+    )
 else:
     from vllm._custom_ops import awq_dequantize

@@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,

@@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,

@@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             if (
                 _is_cuda
-                and _enable_jit_deepgemm_bmm
+                and _ENABLE_JIT_DEEPGEMM
                 and weight_block_size[0] == 128
                 and weight_block_size[1] == 128
                 and model_dtype == torch.bfloat16
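
The two MLA call sites above now go through the new wrapper, which derives N, K, and num_groups from the operand shapes and pre-compiles before dispatching. A shape-only sketch of that unpacking, with purely illustrative dimensions (no FP8 quantization or kernel launch happens here):

import torch

# Illustrative sizes only, not DeepSeek-V3's real dimensions.
num_groups, m_max, k, n = 16, 128, 512, 256

lhs_val = torch.empty(num_groups, m_max, k)   # stands in for FP8 activations
rhs_val = torch.empty(num_groups, n, k)       # stands in for FP8 weights
out = torch.empty(num_groups, m_max, n)
masked_m = torch.randint(1, m_max, (num_groups,))  # valid rows per group

# Mirrors grouped_gemm_nt_f8f8bf16_masked's unpacking above:
g, _, k_ = lhs_val.shape
_, n_, _ = rhs_val.shape
print(g, n_, k_)  # (num_groups, n, k) determine which kernels get pre-compiled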
python/sglang/srt/utils.py

@@ -98,6 +98,16 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value in truthy_values


+def get_int_env_var(name: str, default: int = 0) -> int:
+    value = os.getenv(name)
+    if value is None or not value.strip():
+        return default
+    try:
+        return int(value)
+    except ValueError:
+        return default
+
+
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     return torch.version.hip is not None
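
The new `get_int_env_var` helper falls back to the default when the variable is unset, blank, or not a valid integer. A quick usage sketch (the helper is repeated so the snippet runs on its own; the environment values are examples only):

import os


def get_int_env_var(name: str, default: int = 0) -> int:
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError:
        return default


os.environ["SGL_JIT_DEEPGEMM_COMPILE_WORKERS"] = "8"
print(get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4))  # 8

os.environ["SGL_JIT_DEEPGEMM_COMPILE_WORKERS"] = "not-a-number"
print(get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4))  # 4 (falls back)

print(get_int_env_var("SGL_SOME_UNSET_VAR", 4))                # 4 (unset)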