sglang / Commits / 915140fd

Commit 915140fd (unverified)
Authored Aug 04, 2025 by azhurkevich; committed by GitHub on Aug 04, 2025
Parent: 36fc9260

[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)

Co-authored-by: Cheng Wan <cwan@x.ai>

Showing 8 changed files with 502 additions and 115 deletions (+502 -115)
python/sglang/srt/layers/moe/ep_moe/layer.py             +18   -7
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +173  -16
python/sglang/srt/layers/moe/utils.py                    +16   -0
python/sglang/srt/layers/quantization/modelopt_quant.py  +260  -63
python/sglang/srt/managers/schedule_batch.py             +1    -1
python/sglang/srt/models/deepseek_v2.py                  +25   -24
python/sglang/srt/models/glm4_moe.py                     +2    -4
python/sglang/srt/server_args.py                         +7    -0
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
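For readers following the selection order, here is a minimal standalone sketch of the branch priority introduced above (illustrative only; the plain dictionary and its values are assumptions, not the real global_server_args_dict):

# Standalone sketch of the selection order in get_moe_impl_class() after this change.
# Values below are hypothetical server args, not taken from the commit.
def pick_moe_impl(args: dict) -> str:
    if args.get("moe_a2a_backend") == "deepep":
        return "DeepEPMoE"
    # New FP4 branch: checked before the CUTLASS/EP branches, so it applies
    # regardless of expert-parallel settings.
    if args.get("enable_flashinfer_trtllm_moe", False) and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if args.get("enable_flashinfer_cutlass_moe", False):
        return "FusedMoE"
    return "FusedMoE or EP variants (rest of the function not shown in this hunk)"

print(pick_moe_impl({"enable_flashinfer_trtllm_moe": True, "quantization": "modelopt_fp4"}))
# -> FlashInferFP4MoE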
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import importlib.util
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
from packaging import version as pkg_version

from sglang.srt.distributed import (
    get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )

 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)


-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim


 class FusedMoeWeightScaleSupported(Enum):
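A worked example of the tile-sizing heuristic added above (standalone sketch, not the sglang implementation; next_power_of_2 here is a local stand-in for the imported utility):

# Worked example of _get_tile_tokens_dim's heuristic.
def next_power_of_2(n: int) -> int:
    # Local stand-in for sglang.srt.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    per_expert = (num_tokens * top_k) // num_experts       # assume perfect balance
    return min(max(next_power_of_2(per_expert), 8), 64)    # clamp to the 8-64 kernel range

print(tile_tokens_dim(64, 8, 256))   # 512 // 256 = 2 -> padded to 2 -> clamped up to 8
print(tile_tokens_dim(4096, 8, 32))  # 1024 -> already a power of 2 -> clamped down to 64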
@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
         assert self.quant_method is not None

         self.quant_config = quant_config
@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(torch.float8_e4m3fn),
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(torch.float8_e4m3fn),
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE
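As a rough aid for reading _quantize_hidden_states_fp4 above, a small shape sketch (the sizes below are illustrative assumptions, not taken from the commit):

# Shape sketch for the per-forward FP4 activation quantization (illustrative only).
num_tokens, hidden_size = 8, 7168   # hypothetical decode batch and hidden size
sf_vec_size = 16                    # one FP8 (e4m3) block scale per 16 elements

packed_fp4_shape = (num_tokens, hidden_size // 2)             # two 4-bit values per uint8
num_block_scales = num_tokens * (hidden_size // sf_vec_size)  # flattened scale count

print(packed_fp4_shape)   # (8, 3584)
print(num_block_scales)   # 3584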
python/sglang/srt/layers/moe/utils.py
import importlib.util
from enum import Enum
from functools import lru_cache

from packaging import version as pkg_version

from sglang.srt.managers.schedule_batch import global_server_args_dict


@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
        not importlib.util.find_spec("flashinfer")
        or pkg_version.parse(__import__("flashinfer").__version__)
        >= pkg_version.parse("0.2.9rc1")
    )
    return result


class MoeA2ABackend(Enum):
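One consequence of the @lru_cache(maxsize=1) decorator above is that the flag/version check runs once per process and the cached result is reused afterwards. A small sketch of that caching behavior (stand-in function, not the sglang one):

from functools import lru_cache

call_count = 0

@lru_cache(maxsize=1)
def gate() -> bool:   # stand-in for should_use_flashinfer_trtllm_moe()
    global call_count
    call_count += 1
    return True       # pretend the server arg and flashinfer version check passed

gate(); gate(); gate()
print(call_count)  # 1 -- later changes to server args would not be re-read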
python/sglang/srt/layers/quantization/modelopt_quant.py

(Diff collapsed in this view; +260 -63 not shown.)
python/sglang/srt/managers/schedule_batch.py
@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
-from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_triton_kernel_moe",
     "enable_multimodal",
     "enable_symm_mem",
+    "quantization",
 ]

 # Put some global args for easy access
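Why "quantization" is added here: keys listed in GLOBAL_SERVER_ARGS_KEYS are the ones exposed through global_server_args_dict, which the new FP4 detection reads. A simplified sketch of that relationship (the population loop below is an illustration, not the actual scheduler code):

# Simplified illustration: only keys listed in GLOBAL_SERVER_ARGS_KEYS end up in
# global_server_args_dict, so "quantization" must be listed for the FP4 check to see it.
GLOBAL_SERVER_ARGS_KEYS = ["enable_flashinfer_trtllm_moe", "quantization"]  # excerpt

server_args = {"enable_flashinfer_trtllm_moe": True, "quantization": "modelopt_fp4"}
global_server_args_dict = {k: server_args[k] for k in GLOBAL_SERVER_ARGS_KEYS}

print(global_server_args_dict.get("quantization"))  # modelopt_fp4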
python/sglang/srt/models/deepseek_v2.py
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -307,19 +304,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
-        )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

         self.experts = get_moe_impl_class()(
@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
python/sglang/srt/models/glm4_moe.py
@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
python/sglang/srt/server_args.py
@@ -481,6 +481,13 @@ class ServerArgs:
                 self.tp_size,
             ], "The expert parallel size must be 1 or the same as the tensor parallel size"

+        if self.enable_flashinfer_trtllm_moe:
+            if not self.disable_shared_experts_fusion:
+                self.disable_shared_experts_fusion = True
+                logger.warning(
+                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
+                )
+
         # DeepEP MoE
         if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
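Putting the pieces of this commit together: the low-latency NVFP4 decode path is selected when the TRTLLM MoE flag is set, the model uses modelopt_fp4 quantization, and the flashinfer version gate passes; enabling the flag also forces disable_shared_experts_fusion as shown in the hunk above. A hedged summary sketch (the helper function name is made up; the version check mirrors should_use_flashinfer_trtllm_moe, everything else is illustrative):

import importlib.util

from packaging import version as pkg_version


def nvfp4_trtllm_decode_selected(enable_flashinfer_trtllm_moe: bool, quantization: str) -> bool:
    """Summary sketch of the gating added in this commit (not an sglang API)."""
    if not enable_flashinfer_trtllm_moe or quantization != "modelopt_fp4":
        return False
    # Mirrors should_use_flashinfer_trtllm_moe(): pass if flashinfer is absent
    # or if its version is at least 0.2.9rc1.
    return not importlib.util.find_spec("flashinfer") or pkg_version.parse(
        __import__("flashinfer").__version__
    ) >= pkg_version.parse("0.2.9rc1")


print(nvfp4_trtllm_decode_selected(True, "modelopt_fp4"))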