sglang · Commit 88596739 (unverified)

Support running FP4 Deepseek on SM120. (#11708)

Authored Oct 28, 2025 by weiliang; committed by GitHub, Oct 27, 2025.
Parent commit: a6ea3add
Showing 9 changed files with 33 additions and 35 deletions (+33, -35):

  python/sglang/srt/layers/attention/flashinfer_backend.py (+2, -2)
  python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+2, -2)
  python/sglang/srt/layers/quantization/fp8_utils.py (+2, -2)
  python/sglang/srt/layers/quantization/modelopt_quant.py (+8, -7)
  python/sglang/srt/models/deepseek_v2.py (+1, -5)
  python/sglang/srt/models/gpt_oss.py (+1, -10)
  python/sglang/srt/server_args.py (+4, -3)
  python/sglang/srt/utils/common.py (+10, -1)
  sgl-kernel/tests/test_fp8_blockwise_moe.py (+3, -3)
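The common thread across these files is that the SM100-only gate is_sm100_supported() is replaced by is_blackwell_supported(), which also accepts SM120 devices. As a rough standalone illustration (not the project's helper itself, which additionally requires CUDA >= 12.8; the function name here is hypothetical), the capability check reduces to:

    import torch

    def looks_like_blackwell(device=None) -> bool:
        # SM100 parts (e.g. B200) report compute capability major 10,
        # SM120 parts report major 12; both count as Blackwell here.
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability(device)
        return major in (10, 12)

    print(looks_like_blackwell())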
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -26,8 +26,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
+    is_blackwell_supported,
     is_flashinfer_available,
-    is_sm100_supported,
     next_power_of_2,
 )
@@ -229,7 +229,7 @@ class FlashInferAttnBackend(AttentionBackend):
         ]
         fmha_backend = "auto"
-        if is_sm100_supported():
+        if is_blackwell_supported():
             # Disable CUTLASS backend when piecewise cuda graph is enabled
             # due to TMA descriptor initialization issues on B200
             if model_runner.server_args.enable_piecewise_cuda_graph:
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -25,8 +25,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    is_blackwell_supported,
     is_flashinfer_available,
-    is_sm100_supported,
     next_power_of_2,
 )
@@ -243,7 +243,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.q_indptr_decode = q_indptr_decode_buf
         self.fmha_backend = "auto"
-        if is_sm100_supported():
+        if is_blackwell_supported():
             self.fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=self.fmha_backend
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader

 try:
     from vllm import _custom_ops as ops
@@ -129,7 +129,7 @@ def cutlass_block_fp8_supported() -> bool:
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
 ENABLE_FLASHINFER_GEMM = (
     get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
-    and is_sm100_supported()
+    and is_blackwell_supported()
     and is_flashinfer_available()
 )
 if ENABLE_FLASHINFER_GEMM:
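The ENABLE_FLASHINFER_GEMM gate above combines an environment variable, the new hardware check, and a flashinfer import probe. A standalone sketch of that combination, with hypothetical local helpers standing in for sglang's get_bool_env_var, is_blackwell_supported, and is_flashinfer_available:

    import os
    import torch

    def _blackwell_class_gpu() -> bool:
        # Stand-in for is_blackwell_supported(): major 10 (SM100) or 12 (SM120).
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] in (10, 12)

    def _flashinfer_importable() -> bool:
        # Stand-in for is_flashinfer_available(): probe the import.
        try:
            import flashinfer  # noqa: F401
            return True
        except ImportError:
            return False

    ENABLE_FLASHINFER_GEMM = (
        os.environ.get("SGLANG_ENABLE_FLASHINFER_GEMM", "").lower() in ("1", "true", "yes")
        and _blackwell_class_gpu()
        and _flashinfer_importable()
    )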
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -28,7 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
-    is_sm100_supported,
+    is_blackwell_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
@@ -49,8 +49,10 @@ if TYPE_CHECKING:
 )
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

-if is_cuda():
-    from sgl_kernel import scaled_fp4_quant
+try:
+    from flashinfer import fp4_quantize
+except ImportError:
+    fp4_quantize = None

 try:
     from flashinfer import mm_fp4 as fp4_gemm
@@ -867,10 +869,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
+        x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
-        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.weight.dtype == torch.uint8
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
@@ -903,7 +904,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     def __init__(self, quant_config: ModelOptFp4Config):
         self.quant_config = quant_config
-        if not is_sm100_supported():
+        if not is_blackwell_supported():
             raise ValueError(
                 "Current platform does not support NVFP4"
                 " quantization. Please use Blackwell and"
@@ -1410,7 +1411,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         output_dtype = x.dtype
         x_sf = None
         if should_use_flashinfer_cutlass_moe_fp4_allgather():
-            from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+            from flashinfer import nvfp4_block_scale_interleave

             # Quantize before comm, swizzle after.
             if x.shape[0] > 0:
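These hunks switch FP4 activation quantization from sgl_kernel's scaled_fp4_quant to flashinfer's fp4_quantize and make the import optional. A minimal sketch of how a call site can stay importable when flashinfer is absent (quantize_activation is a hypothetical wrapper, not part of this change):

    try:
        from flashinfer import fp4_quantize
    except ImportError:  # flashinfer not installed on this platform
        fp4_quantize = None

    def quantize_activation(x, input_scale_inv):
        # Fail only when the NVFP4 path is actually exercised.
        if fp4_quantize is None:
            raise RuntimeError("NVFP4 quantization requires flashinfer")
        return fp4_quantize(x, input_scale_inv)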
python/sglang/srt/models/deepseek_v2.py

@@ -131,13 +131,11 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cpu,
     is_cuda,
-    is_flashinfer_available,
     is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
     is_nvidia_cublas_cu12_version_ge_12_9,
-    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -197,8 +195,6 @@ elif _is_npu:
 else:
     pass

-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
 _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()

 logger = logging.getLogger(__name__)
@@ -1260,7 +1256,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
             and _is_cuda
-            and _device_sm >= 90
+            and 90 <= _device_sm < 120
         )
         self.qkv_proj_with_rope_is_int8 = (
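The tightened condition 90 <= _device_sm < 120 keeps this fused qkv-projection fast path on Hopper- and SM100-class GPUs while excluding SM120. A hypothetical illustration, assuming _device_sm encodes the compute capability as major * 10 + minor (e.g. 90 for H100, 100 for B200, 120 for SM120 parts):

    import torch

    if torch.cuda.is_available():
        # Illustration only: assumes _device_sm = major * 10 + minor.
        major, minor = torch.cuda.get_device_capability()
        device_sm = major * 10 + minor
        old_guard = device_sm >= 90        # also admitted SM120 (120)
        new_guard = 90 <= device_sm < 120  # Hopper (90) and SM100 (100), not SM120
        print(device_sm, old_guard, new_guard)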
python/sglang/srt/models/gpt_oss.py

@@ -70,18 +70,9 @@ from sglang.srt.models.utils import (
     enable_fused_set_kv_buffer,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import (
-    LazyValue,
-    add_prefix,
-    is_cuda,
-    is_flashinfer_available,
-    is_sm100_supported,
-    make_layers,
-)
+from sglang.srt.utils import LazyValue, add_prefix, is_cuda, make_layers

 _is_cuda = is_cuda()
-_is_flashinfer_available = is_flashinfer_available()
-_is_sm100_supported = is_cuda() and is_sm100_supported()

 if _is_cuda:
python/sglang/srt/server_args.py

@@ -39,6 +39,7 @@ from sglang.srt.utils.common import (
     get_device,
     get_device_memory_capacity,
     get_device_sm,
+    is_blackwell_supported,
     is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
@@ -913,7 +914,7 @@ class ServerArgs:
                 f"- Decode: {decode_attn_backend}\n"
             )
-        if is_sm100_supported():
+        if is_blackwell_supported():
             if not self.enable_dp_attention:
                 self.enable_flashinfer_allreduce_fusion = True
                 logger.info(
@@ -925,7 +926,7 @@ class ServerArgs:
             and quantization_config.get("quant_method") == "mxfp4"
         )
-        if is_sm100_supported() and is_mxfp4_quant_format:
+        if is_blackwell_supported() and is_mxfp4_quant_format:
            self.moe_runner_backend = "flashinfer_mxfp4"
            logger.warning(
                "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
@@ -1145,7 +1146,7 @@ class ServerArgs:
             self.attention_backend == "trtllm_mla"
             or self.decode_attention_backend == "trtllm_mla"
         ):
-            if not is_sm100_supported():
+            if not is_blackwell_supported():
                 raise ValueError(
                     "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
                 )
python/sglang/srt/utils/common.py

@@ -188,7 +188,16 @@ is_hopper_with_cuda_12_3 = lambda: _check(9)
 def is_blackwell():
     if not is_cuda():
         return False
-    return torch.cuda.get_device_capability()[0] == 10
+    return torch.cuda.get_device_capability()[0] in [10, 12]
+
+
+@lru_cache(maxsize=1)
+def is_blackwell_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
+    return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
+        torch.version.cuda >= "12.8"
+    )


 @lru_cache(maxsize=1)
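This is where the new helper lives; callers elsewhere in the tree (server_args.py, the FlashInfer attention backends, fp8_utils.py, modelopt_quant.py) import it in place of is_sm100_supported. A minimal usage sketch, assuming sglang is installed and importable:

    from sglang.srt.utils.common import is_blackwell_supported

    # True only on a CUDA-alike device with compute capability major 10 or 12
    # and CUDA >= 12.8; otherwise callers fall back to non-FP4 code paths.
    if is_blackwell_supported():
        print("Blackwell-class GPU detected; NVFP4 kernels can be enabled")
    else:
        print("Not a Blackwell-class GPU; using fallback paths")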
sgl-kernel/tests/test_fp8_blockwise_moe.py

@@ -86,8 +86,8 @@ def baseline_scaled_mm(
     ).to(out_dtype)


-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+def is_blackwell_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] in [10, 12]) and (
         torch.version.cuda >= "12.8"
     )
@@ -99,7 +99,7 @@ def is_sm90_supported(device=None) -> bool:
 @pytest.mark.skipif(
-    not (is_sm100_supported() or is_sm90_supported()),
+    not (is_blackwell_supported() or is_sm90_supported()),
     reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
 )
 @pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])