Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
82f39dc1
Unverified
Commit
82f39dc1
authored
Nov 05, 2025
by
Shu Wang
Committed by
GitHub
Nov 05, 2025
Browse files
Add mm_fp4 trtllm backend (#12406)
parent
627bac64
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
5 deletions
+31
-5
docs/references/environment_variables.md
docs/references/environment_variables.md
+1
-0
python/sglang/srt/environ.py
python/sglang/srt/environ.py
+2
-0
python/sglang/srt/layers/quantization/modelopt_quant.py
python/sglang/srt/layers/quantization/modelopt_quant.py
+28
-5
No files found.
docs/references/environment_variables.md
View file @
82f39dc1
...
...
@@ -66,6 +66,7 @@ SGLang supports various environment variables that can be used to configure its
|
`SGLANG_MOE_PADDING`
| Enable MoE padding (sets padding size to 128 if value is
`1`
, often set to
`1`
in Docker builds) |
`0`
|
|
`SGLANG_FORCE_FP8_MARLIN`
| Force using FP8 MARLIN kernels even if other FP8 kernels are available |
`false`
|
|
`SGLANG_ENABLE_FLASHINFER_GEMM`
| Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs |
`false`
|
| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` | Select backend for `mm_fp4` on Blackwell GPUs | `` |
| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
| `SGLANG_CUTLASS_MOE` (deprecated) | Use Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated, use --moe-runner-backend=cutlass) | `false` |
...
...
python/sglang/srt/environ.py
View file @
82f39dc1
...
...
@@ -198,6 +198,8 @@ class Envs:
# Flashinfer
SGLANG_IS_FLASHINFER_AVAILABLE
=
EnvBool
(
True
)
SGLANG_ENABLE_FLASHINFER_GEMM
=
EnvBool
(
False
)
# Default to the pick from flashinfer
SGLANG_FLASHINFER_FP4_GEMM_BACKEND
=
EnvStr
(
""
)
SGLANG_FLASHINFER_WORKSPACE_SIZE
=
EnvInt
(
384
*
1024
*
1024
)
# Triton
...
...
python/sglang/srt/layers/quantization/modelopt_quant.py
View file @
82f39dc1
...
...
@@ -11,6 +11,7 @@ from sglang.srt.distributed import get_tp_group
from
sglang.srt.distributed.device_communicators.pynccl_allocator
import
(
use_symmetric_memory
,
)
from
sglang.srt.environ
import
envs
from
sglang.srt.layers.dp_attention
import
(
get_dp_global_num_tokens
,
get_local_dp_buffer
,
...
...
@@ -94,14 +95,12 @@ logger = logging.getLogger(__name__)
CUTEDSL_MOE_SCALAR_INPUT_SCALE
=
get_bool_env_var
(
"SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE"
,
"true"
)
USE_CUTLASS_BACKEND_FOR_FP4_GEMM
=
get_bool_env_var
(
"SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
,
"true"
)
# TODO make it true by default when the DeepEP PR is merged
CUTEDSL_MOE_NVFP4_DISPATCH
=
get_bool_env_var
(
"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH"
,
"false"
)
FLASHINFER_FP4_GEMM_BACKEND
=
envs
.
SGLANG_FLASHINFER_FP4_GEMM_BACKEND
.
get
()
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES
=
[
"static"
]
...
...
@@ -1006,7 +1005,26 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
layer
.
input_scale_inv
=
Parameter
(
(
1
/
input_scale_2
).
to
(
torch
.
float32
),
requires_grad
=
False
)
if
FLASHINFER_FP4_GEMM_BACKEND
==
"trtllm"
:
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from
flashinfer
import
shuffle_matrix_a
,
shuffle_matrix_sf_a
weight
=
layer
.
weight
scale
=
layer
.
weight_scale
epilogue_tile_m
=
128
weight
=
shuffle_matrix_a
(
weight
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
scale
=
(
shuffle_matrix_sf_a
(
scale
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
.
reshape
(
scale
.
shape
)
.
view
(
torch
.
float8_e4m3fn
)
)
layer
.
weight_scale_interleaved
=
Parameter
(
scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
return
# Pad and blockwise interleave weight_scale
scales
=
layer
.
weight_scale
scale_ndim
=
scales
.
ndim
...
...
@@ -1056,6 +1074,11 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if
enable_flashinfer_fp4_gemm
:
w
=
layer
.
weight
.
T
w_scale_interleaved
=
layer
.
weight_scale_interleaved
.
T
# TODO(shuw@nvidia.com)
# Remove the default after flashinfer bumped to 0.5.1
backend
=
(
FLASHINFER_FP4_GEMM_BACKEND
if
FLASHINFER_FP4_GEMM_BACKEND
else
"cutlass"
)
out
=
fp4_gemm
(
x_fp4
,
w
,
...
...
@@ -1063,7 +1086,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
w_scale_interleaved
,
layer
.
alpha
,
output_dtype
,
**
(
dict
(
backend
=
"cutlass"
)
if
USE_CUTLASS_BACKEND_FOR_FP4_GEMM
else
dict
(
)),
**
(
dict
(
backend
=
backend
)),
)
if
bias
is
not
None
:
out
=
out
+
bias
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment