sglang · Commits · 77225d60

Commit 77225d60 (unverified)
Authored Oct 28, 2025 by b8zhong, committed by GitHub on Oct 28, 2025
Parent: 9c6e25d2

Use Flashinfer TRT-LLM as Llama 4 compatible MoE backend (#11928)

Showing 4 changed files with 171 additions and 7 deletions (+171 −7)
Files changed:
  python/sglang/srt/configs/model_config.py                 +1   −0
  python/sglang/srt/layers/quantization/modelopt_quant.py   +160 −4
  python/sglang/srt/model_loader/weight_utils.py            +1   −1
  python/sglang/srt/server_args.py                          +9   −2
python/sglang/srt/configs/model_config.py @ 77225d60

```
@@ -643,6 +643,7 @@ class ModelConfig:
            "petit_nvfp4",
        ]
        compatible_quantization_methods = {
            "modelopt_fp8": ["modelopt"],
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
```
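The added entry records that a user-requested `modelopt_fp8` quantization method is treated as compatible with checkpoints whose own quantization config reports `modelopt`. A minimal sketch of how such a compatibility table can be consulted; the helper below is hypothetical and not ModelConfig's actual method:

```python
compatible_quantization_methods = {
    "modelopt_fp8": ["modelopt"],
    "modelopt_fp4": ["modelopt"],
    "petit_nvfp4": ["modelopt"],
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
}

def is_override_compatible(requested: str, checkpoint_method: str) -> bool:
    """Hypothetical helper: may `requested` override the checkpoint's own method?"""
    # Unlisted methods are only compatible with themselves.
    return checkpoint_method in compatible_quantization_methods.get(requested, [requested])

print(is_override_compatible("modelopt_fp8", "modelopt"))  # True after this commit
```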
python/sglang/srt/layers/quantization/modelopt_quant.py @ 77225d60

```
@@ -25,6 +25,7 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    cutlass_fp8_supported,
```
```
@@ -468,8 +469,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max of the w1 and w3 scales then dequant and requant each expert.
        if layer.w13_weight_scale.dim() == 2:
            # Shape: (num_experts, 2)
            from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

            # Get the maximum scale across w1 and w3 for each expert
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
```
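The comment in this hunk describes the rescaling trick: when w1 and w3 carry separate per-expert scales, the fused w13 tensor needs a single scale, so each half is dequantized with its original scale and requantized with the per-expert maximum. A minimal sketch of that rescaling on toy tensors, in plain PyTorch rather than the sglang kernel; shapes and values are illustrative only:

```python
import torch

num_experts, inter_x2, hidden = 2, 8, 4                # toy sizes
# Quantized weights kept as float here for illustration; the real code uses float8_e4m3fn.
w13_q = torch.randn(num_experts, inter_x2, hidden)
w13_scale = torch.rand(num_experts, 2) + 0.5           # separate scales for w1 and w3

max_scale = w13_scale.max(dim=1).values                # one shared scale per expert
requant = torch.empty_like(w13_q)
half = inter_x2 // 2
for e in range(num_experts):
    for part, sl in enumerate((slice(0, half), slice(half, inter_x2))):
        dequant = w13_q[e, sl] * w13_scale[e, part]    # back to full precision
        requant[e, sl] = dequant / max_scale[e]        # re-express under the shared scale

# requant * max_scale reproduces the original dequantized weights (up to FP8 rounding
# in the real kernel, exactly here since everything stays in float32).
```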
```
@@ -517,6 +516,84 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
                layer.w2_input_scale.max(), requires_grad=False
            )

        # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
        if should_use_flashinfer_trtllm_moe():
            from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a

            # 1) Swap W13 halves: [Up, Gate] -> [Gate, Up] expected by FI
            num_experts, two_n, hidden = layer.w13_weight.shape
            inter = two_n // 2
            w13_swapped = (
                layer.w13_weight.reshape(num_experts, 2, inter, hidden)
                .flip(dims=[1])
                .reshape(num_experts, two_n, hidden)
            )

            # 2) Reorder rows for fused gated activation (W13)
            w13_interleaved = [
                reorder_rows_for_gated_act_gemm(w13_swapped[i])
                for i in range(num_experts)
            ]
            w13_interleaved = torch.stack(w13_interleaved).reshape(
                num_experts, two_n, hidden
            )

            # 3) Shuffle weights for transposed MMA output (both W13, W2)
            epilogue_tile_m = 128
            w13_shuffled = [
                shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
                for i in range(num_experts)
            ]
            w2_shuffled = [
                shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m)
                for i in range(num_experts)
            ]

            layer.w13_weight = Parameter(
                torch.stack(w13_shuffled).view(torch.float8_e4m3fn),
                requires_grad=False,
            )
            layer.w2_weight = Parameter(
                torch.stack(w2_shuffled).view(torch.float8_e4m3fn),
                requires_grad=False,
            )

        # Precompute and register per-expert output scaling factors for FI MoE
        if should_use_flashinfer_trtllm_moe():
            # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
            assert (
                hasattr(layer, "w13_input_scale")
                and layer.w13_input_scale is not None
            )
            assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
            assert (
                hasattr(layer, "w13_weight_scale")
                and layer.w13_weight_scale is not None
            )
            assert (
                hasattr(layer, "w2_weight_scale")
                and layer.w2_weight_scale is not None
            )

            input_scale = layer.w13_input_scale.to(torch.float32)
            activation_scale = layer.w2_input_scale.to(torch.float32)
            w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
            w2_weight_scale = layer.w2_weight_scale.to(torch.float32)

            output1_scales_scalar = (
                w13_weight_scale * input_scale * (1.0 / activation_scale)
            )
            output1_scales_gate_scalar = w13_weight_scale * input_scale
            output2_scales_scalar = activation_scale * w2_weight_scale

            layer.output1_scales_scalar = Parameter(
                output1_scales_scalar, requires_grad=False
            )
            layer.output1_scales_gate_scalar = Parameter(
                output1_scales_gate_scalar, requires_grad=False
            )
            layer.output2_scales_scalar = Parameter(
                output2_scales_scalar, requires_grad=False
            )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
```
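Two ideas in this hunk are easy to check in isolation. First, the [Up, Gate] to [Gate, Up] swap is just a reshape to (E, 2, N, H), a flip of the size-2 axis, and a reshape back. Second, the precomputed output scales fold the per-tensor dequant factors of each GEMM into single multipliers. A toy sketch in plain PyTorch; sizes and scale values are made up, and this is not the FlashInfer kernel itself:

```python
import torch

E, N, H = 2, 3, 4                       # experts, intermediate size, hidden size
w13 = torch.arange(E * 2 * N * H, dtype=torch.float32).reshape(E, 2 * N, H)

# [Up, Gate] -> [Gate, Up]: split the fused dim into two halves and flip them.
swapped = w13.reshape(E, 2, N, H).flip(dims=[1]).reshape(E, 2 * N, H)
assert torch.equal(swapped[:, :N], w13[:, N:])     # gate half moved to the front
assert torch.equal(swapped[:, N:], w13[:, :N])     # up half moved to the back

# Output-scale algebra from the hunk above, with scalar stand-ins for the
# per-expert tensors:
w13_weight_scale = torch.tensor(0.02)   # per-tensor dequant scale of W13
w2_weight_scale = torch.tensor(0.03)    # per-tensor dequant scale of W2
input_scale = torch.tensor(0.5)         # activation quant scale into GEMM1
activation_scale = torch.tensor(0.25)   # activation quant scale into GEMM2

output1_scales_scalar = w13_weight_scale * input_scale * (1.0 / activation_scale)
output1_scales_gate_scalar = w13_weight_scale * input_scale
output2_scales_scalar = activation_scale * w2_weight_scale
# GEMM1's output is dequantized and immediately requantized for GEMM2, hence the
# division by activation_scale; GEMM2's output stays in the model dtype.
```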
```
@@ -528,6 +605,81 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
        from sglang.srt.layers.moe.topk import TopKOutputChecker

        if should_use_flashinfer_trtllm_moe() and TopKOutputChecker.format_is_bypassed(
            topk_output
        ):
            router_logits = topk_output.router_logits
            topk_config = topk_output.topk_config

            # Constraints
            assert (
                self.moe_runner_config.activation == "silu"
            ), "Only silu is supported for flashinfer fp8 moe"

            from flashinfer import RoutingMethodType
            from flashinfer.fused_moe import trtllm_fp8_per_tensor_scale_moe

            correction_bias = (
                None
                if topk_config.correction_bias is None
                else topk_config.correction_bias
            )

            # Pre-quantize activations to FP8 per-tensor using provided input scale
            x_fp8, _ = scaled_fp8_quant(x, layer.w13_input_scale)

            use_routing_scales_on_input = True
            routed_scaling_factor = self.moe_runner_config.routed_scaling_factor

            # Enforce Llama4 routing for ModelOpt FP8 MoE for now.
            # TODO(brayden): support other routing methods
            assert topk_config.top_k == 1, "ModelOpt FP8 MoE requires top_k==1"
            assert (
                not topk_config.num_expert_group
            ), "ModelOpt FP8 MoE does not support expert grouping"
            assert (
                not topk_config.topk_group
            ), "ModelOpt FP8 MoE does not support grouped top-k"
            routing_method_type = RoutingMethodType.Llama4

            # FlashInfer TRTLLM requires routing_logits (and bias) to be bfloat16
            routing_logits_cast = router_logits.to(torch.bfloat16)
            routing_bias_cast = (
                None
                if correction_bias is None
                else correction_bias.to(torch.bfloat16)
            )

            output = trtllm_fp8_per_tensor_scale_moe(
                routing_logits=routing_logits_cast,
                routing_bias=routing_bias_cast,
                hidden_states=x_fp8,
                gemm1_weights=layer.w13_weight,
                output1_scales_scalar=layer.output1_scales_scalar,
                output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
                gemm2_weights=layer.w2_weight,
                output2_scales_scalar=layer.output2_scales_scalar,
                num_experts=layer.num_experts,
                top_k=topk_config.top_k,
                n_group=0,
                topk_group=0,
                intermediate_size=layer.w2_weight.shape[2],
                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
                local_num_experts=layer.num_local_experts,
                routed_scaling_factor=(
                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
                ),
                use_routing_scales_on_input=use_routing_scales_on_input,
                tile_tokens_dim=None,
                routing_method_type=routing_method_type,
            )

            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

            return StandardCombineInput(hidden_states=output)

        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
```
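The fast path hands the TRT-LLM kernel FP8 activations, so `scaled_fp8_quant(x, layer.w13_input_scale)` is the only tensor-wide conversion on the hot path. The gist of per-tensor FP8 quantization can be sketched in a few lines of plain PyTorch; this is a conceptual stand-in for the fused sglang kernel, not its actual implementation, and the clamp bound is simply the finite range of float8_e4m3fn:

```python
import torch

def per_tensor_fp8_quant_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Illustrative only: divide by the dequant scale, saturate, cast to FP8."""
    finfo = torch.finfo(torch.float8_e4m3fn)            # max is roughly 448
    q = (x.float() / scale.float()).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)

x = torch.randn(4, 8)
scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
x_fp8 = per_tensor_fp8_quant_sketch(x, scale)
# Dequantizing recovers x up to FP8 rounding error:
print((x_fp8.float() * scale - x).abs().max())
```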
```
@@ -1384,8 +1536,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        alt_stream=None,
    ) -> CombineInput:
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
```

```
@@ -1398,6 +1548,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

        if self.enable_flashinfer_cutlass_moe:
```

```
@@ -1466,6 +1618,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            if forward_shared_experts is not None:
                torch.cuda.current_stream().wait_stream(alt_stream)

            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

            return StandardCombineInput(hidden_states=output)

        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
```

```
@@ -1487,6 +1641,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)

        # Scale by routed_scaling_factor is fused into select_experts.
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        return StandardCombineInput(hidden_states=output)

    def apply_without_routing_weights(
```
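The four NvFp4 hunks above all do the same thing: every return path now wraps its result in StandardCombineInput so the downstream dispatcher sees a uniform container. A stripped-down illustration of that wrapper pattern; the dataclasses below are hypothetical simplifications, not the real types in sglang.srt.layers.moe.token_dispatcher:

```python
from dataclasses import dataclass
import torch

# Hypothetical stand-ins: the actual dispatcher types carry more fields.
@dataclass
class StandardDispatchOutputSketch:
    hidden_states: torch.Tensor
    topk_output: object

@dataclass
class StandardCombineInputSketch:
    hidden_states: torch.Tensor

def apply_sketch(dispatch_output: StandardDispatchOutputSketch) -> StandardCombineInputSketch:
    # Whatever branch computes the expert output, the result is wrapped the same way.
    x = dispatch_output.hidden_states
    output = x  # placeholder for the expert computation
    return StandardCombineInputSketch(hidden_states=output)
```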
python/sglang/srt/model_loader/weight_utils.py @ 77225d60

```
@@ -238,7 +238,7 @@ def get_quant_config(
        if model_config.quantization == "bitsandbytes":
            config["adapter_name_or_path"] = model_name_or_path
        elif model_config.quantization.startswith("modelopt") and (
-           config["producer"]["name"].startswith("modelopt")
+           config.get("producer", {}).get("name", "").startswith("modelopt")
        ):
            quant_algo = config["quantization"]["quant_algo"]
            if quant_algo is None:
```
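The one-line change swaps direct indexing for chained `.get()` calls, so a quantization config without a `producer` section no longer raises `KeyError`. A quick illustration with a made-up config dict:

```python
config = {"quantization": {"quant_algo": "FP8"}}  # no "producer" section

# Old form: raises KeyError on such configs.
try:
    is_modelopt = config["producer"]["name"].startswith("modelopt")
except KeyError:
    is_modelopt = False

# New form: falls back to an empty string and simply evaluates to False.
is_modelopt = config.get("producer", {}).get("name", "").startswith("modelopt")
print(is_modelopt)  # False
```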
python/sglang/srt/server_args.py @ 77225d60

```
@@ -971,6 +971,11 @@ class ServerArgs:
                logger.warning(
                    "Use trtllm_mha as attention backend on sm100 for Llama4 model"
                )
                if is_sm100_supported() and self.moe_runner_backend == "auto":
                    self.moe_runner_backend = "flashinfer_trtllm"
                    logger.info(
                        "Use flashinfer_trtllm as MoE runner backend on SM100 for Llama4"
                    )
        elif model_arch in [
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",
```
```
@@ -1336,8 +1341,10 @@ class ServerArgs:
        if self.moe_runner_backend == "flashinfer_trtllm":
            assert (
-               self.quantization == "modelopt_fp4" or self.quantization == "fp8"
-           ), "modelopt_fp4 or fp8 quantization is required for Flashinfer TRTLLM MoE"
+               self.quantization == "modelopt_fp4"
+               or self.quantization == "modelopt_fp8"
+               or self.quantization == "fp8"
+           ), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE"
            self.disable_shared_experts_fusion = True
            logger.warning(
                "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
```
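Taken together, the two server_args hunks mean that on SM100 a Llama 4 model now auto-selects the FlashInfer TRT-LLM MoE runner when the backend is left on "auto", and the later validation accepts `modelopt_fp8` alongside `modelopt_fp4` and `fp8`. A standalone restatement of that validation logic, simplified; the real check lives on ServerArgs and also forces shared-experts fusion off:

```python
def check_flashinfer_trtllm_moe(quantization: str, moe_runner_backend: str) -> None:
    # Simplified restatement of the ServerArgs check after this commit.
    if moe_runner_backend != "flashinfer_trtllm":
        return
    allowed = {"modelopt_fp4", "modelopt_fp8", "fp8"}
    assert quantization in allowed, (
        "modelopt_fp4, modelopt_fp8 or fp8 quantization is required "
        "for Flashinfer TRTLLM MoE"
    )

check_flashinfer_trtllm_moe("modelopt_fp8", "flashinfer_trtllm")  # passes after this commit
```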