change / sglang · Commits

Commit 915140fd (unverified)
Authored Aug 04, 2025 by azhurkevich; committed by GitHub on Aug 04, 2025

[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)

Co-authored-by: Cheng Wan <cwan@x.ai>
Parent: 36fc9260
Showing 8 changed files with 502 additions and 115 deletions (+502 -115)

    python/sglang/srt/layers/moe/ep_moe/layer.py              +18   -7
    python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +173  -16
    python/sglang/srt/layers/moe/utils.py                      +16   -0
    python/sglang/srt/layers/quantization/modelopt_quant.py   +260  -63
    python/sglang/srt/managers/schedule_batch.py                +1   -1
    python/sglang/srt/models/deepseek_v2.py                    +25  -24
    python/sglang/srt/models/glm4_moe.py                        +2   -4
    python/sglang/srt/server_args.py                            +7   -0
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
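Note on the dispatch above: the new FP4 branch keys only on two server args (enable_flashinfer_trtllm_moe and quantization == "modelopt_fp4") and is checked before any expert-parallel conditions. For illustration, a standalone sketch of the same ordering, with a plain dict standing in for sglang's global_server_args_dict and strings standing in for the real MoE classes (illustrative only, not part of this commit):

# Standalone sketch of the dispatch order in get_moe_impl_class() above.
# `args` stands in for sglang's global_server_args_dict; strings stand in for
# the real MoE classes, and the EP/default branches are elided.
def pick_moe_impl(args: dict) -> str:
    if args.get("moe_a2a_backend") == "deepep":
        return "DeepEPMoE"
    # FP4 + TRTLLM is checked before any expert-parallel conditions.
    if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    return "FusedMoE"  # EP and default branches elided


print(pick_moe_impl({"enable_flashinfer_trtllm_moe": True, "quantization": "modelopt_fp4"}))
# -> FlashInferFP4MoE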
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

-import importlib.util
 import datetime
 import glob
 import logging
 import os
 import sys
 from enum import Enum
-from functools import lru_cache
 from typing import List, Optional, Tuple

 import torch
-from packaging import version as pkg_version

 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )

 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)


-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim


 class FusedMoeWeightScaleSupported(Enum):
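A worked example of the _get_tile_tokens_dim arithmetic added above, with a local re-implementation of next_power_of_2 so the snippet runs on its own (sglang ships its own helper in sglang.srt.utils); the token/expert counts are illustrative:

# Local stand-in for sglang.srt.utils.next_power_of_2 so this snippet runs on its own.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Same arithmetic as _get_tile_tokens_dim above.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
    return min(max(tile_tokens_dim, 8), 64)  # kernel supports 8-64 tokens per CTA tile


# Low-latency decode, e.g. 4 tokens routed over 256 experts with top_k=8:
print(get_tile_tokens_dim(4, 8, 256))    # 0 tokens/expert on average -> clamped to 8
# A larger batch fills the tile: 512 tokens -> 16 tokens/expert -> tile of 16.
print(get_tile_tokens_dim(512, 8, 256))  # 16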
@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
         assert self.quant_method is not None

         self.quant_config = quant_config
@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE
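Shape note for _quantize_hidden_states_fp4 above: fp4_quantize packs two e2m1 values per uint8 and emits one float8_e4m3fn block scale per 16 elements (sf_vec_size=16), so the packed activation is [num_tokens, hidden_size // 2] and the flattened scale vector has num_tokens * hidden_size // 16 entries. A minimal sketch of that bookkeeping with illustrative sizes (not part of the commit):

# Shape-only sketch of the per-forward activation quantization in
# _quantize_hidden_states_fp4: two e2m1 values are packed per uint8 and one
# float8_e4m3fn block scale is produced per 16 elements (sf_vec_size=16).
def fp4_activation_shapes(num_tokens: int, hidden_size: int):
    packed = (num_tokens, hidden_size // 2)            # uint8, 2 fp4 values per byte
    block_scales = (num_tokens * hidden_size // 16,)   # flattened float8_e4m3fn scales
    return packed, block_scales


# Illustrative sizes only, e.g. 4 decode tokens with a 7168-wide hidden state:
print(fp4_activation_shapes(4, 7168))  # ((4, 3584), (1792,))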
python/sglang/srt/layers/moe/utils.py
+import importlib.util
 from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result


 class MoeA2ABackend(Enum):
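The relocated should_use_flashinfer_trtllm_moe gates on both the server flag and the installed flashinfer version (the diff requires >= 0.2.9rc1). A minimal sketch of just the version check, assuming the packaging module is available (illustrative, not part of the commit):

# Minimal sketch of the version gate in should_use_flashinfer_trtllm_moe: the
# TRTLLM path is allowed when flashinfer is not installed (the import failure
# then surfaces later) or when the installed version is at least 0.2.9rc1.
import importlib.util

from packaging import version as pkg_version


def flashinfer_supports_trtllm_moe() -> bool:
    if importlib.util.find_spec("flashinfer") is None:
        return True
    import flashinfer

    return pkg_version.parse(flashinfer.__version__) >= pkg_version.parse("0.2.9rc1")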
python/sglang/srt/layers/quantization/modelopt_quant.py
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations

 import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.nn.parameter import Parameter

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2

 if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():
 try:
     from flashinfer import mm_fp4 as fp4_gemm
+    from flashinfer import (
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )

     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
 else:
     fp4_gemm = None
     enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None

 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
                 return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptNvFp4FusedMoEMethod(self)
         return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_cutlass_moe = False
+        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+    @property
+    def enable_flashinfer_cutlass_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutlass_moe setting."""
+        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)

     def create_weights(
         self,
@@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " dynamic quantization is not supported."
             )

         # TODO(ch-wan): check if this is needed
-        layer.num_experts = num_experts
+        layer.num_local_experts = num_experts
         layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
@@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # GEMM 2
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // 2,
@@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
@@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )

         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

         w2_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )

         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)

         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)

-    def swizzle_blockscale(self, scale: torch.tensor):
+    def swizzle_blockscale(self, scale: torch.Tensor):
         assert scale.dtype == torch.float8_e4m3fn
         # Pad and blockwise interleave weight_scale
         scale_ndim = scale.ndim
@@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             else swizzled_scale.reshape(B, M, K)
         )

+    def prepare_static_weights_for_kernel(
+        self,
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import (
+            RoutingMethodType,
+            e2m1_and_ufp8sf_scale_to_float,
+            fp4_quantize,
+            next_positive_power_of_2,
+            reorder_rows_for_gated_act_gemm,
+            shuffle_matrix_a,
+            shuffle_matrix_sf_a,
+        )
+
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2
+        )  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 16
+        )  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2
+        )  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, hidden_size, intermediate_size // 16
+        )  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+            )
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+            )
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m)
+            )
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+        )
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, hidden_size, intermediate_size // 16)
+        )
+
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         """Process FP4 MoE weights after loading from serialized checkpoint.
-        # GEMM 1
         Only supports pre-quantized checkpoints with FP8 weights and scales.
         """
+        # GEMM 1 scale processing
         if not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
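Shape bookkeeping implied by the reshapes in prepare_static_weights_for_kernel above: packed FP4 weights hold two values per byte and each block scale covers 16 values, so w13 is [E, 2*I, H/2] with scales [E, 2*I, H/16], and w2 is [E, H, I/2] with scales [E, H, I/16]. A small sketch with illustrative sizes (not tied to any particular checkpoint):

# Shape bookkeeping for the static weight preparation above: packed FP4 stores
# two values per uint8 and one float8_e4m3fn block scale covers 16 values.
def fp4_moe_weight_shapes(num_experts: int, hidden_size: int, intermediate_size: int):
    return {
        "w13 packed fp4": (num_experts, 2 * intermediate_size, hidden_size // 2),
        "w13 block scales": (num_experts, 2 * intermediate_size, hidden_size // 16),
        "w2 packed fp4": (num_experts, hidden_size, intermediate_size // 2),
        "w2 block scales": (num_experts, hidden_size, intermediate_size // 16),
    }


# Illustrative sizes only (not tied to a particular checkpoint):
for name, shape in fp4_moe_weight_shapes(256, 7168, 2048).items():
    print(name, shape)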
@@ -880,73 +1020,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        if self.enable_flashinfer_cutlass_moe:
+        # Calculate input scales based on strategy
+        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            w2_input_scale = layer.w2_input_scale

+        # Create shared parameters
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
-        assert (
-            layer.w13_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
+        layer.g2_alphas = Parameter(
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
         )
-        del layer.w13_weight_scale

         # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False
         )
+        layer.w2_input_scale_quant = Parameter(
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
+        )

-        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        # Validate weight scales
+        for name, weight_scale in [
+            ("w13", layer.w13_weight_scale),
+            ("w2", layer.w2_weight_scale),
+        ]:
+            assert (
+                weight_scale.shape[2] % 16 == 0
+            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+            assert (
+                weight_scale.dtype == torch.float8_e4m3fn
+            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"

-        # GEMM 2
-        if self.enable_flashinfer_cutlass_moe:
-            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
-        else:
-            w2_input_scale = layer.w2_input_scale
-
-        layer.g2_alphas = Parameter(
-            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
-            requires_grad=False,
-        )
-
-        # This is for quantization, so we need to invert it.
-        layer.w2_input_scale_quant = Parameter(
-            (1 / w2_input_scale).to(torch.float32), requires_grad=False
-        )
-
-        assert (
-            layer.w2_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(
-            w2_blockscale_swizzled, requires_grad=False
-        )
-        del layer.w2_weight_scale
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
-
-        device = layer.w13_weight.device
-        layer.cutlass_moe_params = CutlassMoEParams(
-            CutlassMoEType.BlockscaledFP4,
-            device,
-            num_experts=layer.num_experts,  # global num experts
-            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
-            hidden_size=layer.w13_weight.shape[2] * 2,
-        )  # k
+        # Weight processing based on strategy
+        if (
+            self.enable_flashinfer_trtllm_moe
+            and reorder_rows_for_gated_act_gemm is not None
+            and shuffle_matrix_sf_a is not None
+        ):
+            # FlashInfer TRTLLM processing - handles both w13 and w2
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = self.prepare_static_weights_for_kernel(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
+
+            # Set flashinfer parameters
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del (
+                layer.w2_weight,
+                layer.w2_weight_scale,
+                layer.w13_weight,
+                layer.w13_weight_scale,
+            )
+
+            print("Applied flashinfer weight processing for both w13 and w2")
+
+        else:
+            # CUTLASS processing - handle w13 and w2 separately
+
+            # Process w13 weights
+            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(
+                w13_blockscale_swizzled, requires_grad=False
+            )
+            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+            # Process w2 weights
+            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(
+                w2_blockscale_swizzled, requires_grad=False
+            )
+            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+            # Both flashinfer cutlass and regular cutlass use same processing for w2
+            print("Applied weight processing for both w13 and w2")
+
+            # Set up CUTLASS MoE parameters
+            device = layer.w13_weight.device
+            layer.cutlass_moe_params = CutlassMoEParams(
+                CutlassMoEType.BlockscaledFP4,
+                device,
+                num_experts=layer.num_experts,  # global num experts
+                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+                hidden_size=layer.w13_weight.shape[2] * 2,
+            )  # k

     @property
     def load_up_proj_weight_first(self) -> bool:
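For reference, the scalar scale algebra set up in process_weights_after_loading above reduces to a few products: g1_alphas = w13_input_scale * w13_weight_scale_2, g2_alphas = w2_input_scale * w2_weight_scale_2, the *_input_scale_quant values are reciprocals of the input scales, and the extra TRT-LLM term is g1_scale_c = w2_input_scale_quant * g1_alphas. A tiny numeric sketch with made-up values:

# Made-up scalar values, only to show how the scale products above combine.
w13_input_scale, w13_weight_scale_2 = 0.5, 0.25
w2_input_scale, w2_weight_scale_2 = 0.8, 0.125

g1_alphas = w13_input_scale * w13_weight_scale_2   # 0.125 (activation x weight global scale, GEMM1)
g2_alphas = w2_input_scale * w2_weight_scale_2     # 0.1   (activation x weight global scale, GEMM2)
w13_input_scale_quant = 1 / w13_input_scale        # 2.0   (used when quantizing hidden states)
w2_input_scale_quant = 1 / w2_input_scale          # 1.25  (used when re-quantizing GEMM1 output)
g1_scale_c = w2_input_scale_quant * g1_alphas      # 0.15625, passed as output1_scale_scalar

print(g1_alphas, g2_alphas, g1_scale_c)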
@@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
+
+        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            return layer.forward(x, topk_output)
+
         if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
@@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

-        topk_weights, topk_ids, _ = topk_output
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
python/sglang/srt/managers/schedule_batch.py
@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_triton_kernel_moe",
     "enable_multimodal",
     "enable_symm_mem",
+    "quantization",
 ]

 # Put some global args for easy access
python/sglang/srt/models/deepseek_v2.py
@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -307,19 +304,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
-        )
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

         self.experts = get_moe_impl_class()(
@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor
@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
python/sglang/srt/models/glm4_moe.py
@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
python/sglang/srt/server_args.py
@@ -481,6 +481,13 @@ class ServerArgs:
             self.tp_size,
         ], "The expert parallel size must be 1 or the same as the tensor parallel size"

+        if self.enable_flashinfer_trtllm_moe:
+            if not self.disable_shared_experts_fusion:
+                self.disable_shared_experts_fusion = True
+                logger.warning(
+                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
+                )
+
         # DeepEP MoE
         if self.moe_a2a_backend == "deepep":
            if self.deepep_mode == "normal":
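The validation added above couples the new FP4 decode path to shared-experts fusion: when FlashInfer TRTLLM MoE is requested, the fusion is forced off. A minimal sketch of that rule with a stand-in dataclass (the real ServerArgs lives in python/sglang/srt/server_args.py and has many more fields; only the attributes touched in this hunk are modeled):

# Stand-in for the ServerArgs attributes touched by this hunk, illustrative only.
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class MiniServerArgs:
    quantization: Optional[str] = None
    enable_flashinfer_trtllm_moe: bool = False
    disable_shared_experts_fusion: bool = False

    def check_trtllm_moe(self):
        # Same rule as the added validation: TRTLLM MoE implies no shared-experts fusion.
        if self.enable_flashinfer_trtllm_moe and not self.disable_shared_experts_fusion:
            self.disable_shared_experts_fusion = True
            logger.warning(
                "FlashInfer TRTLLM MoE is enabled. "
                "--disable-shared-experts-fusion is automatically set."
            )


args = MiniServerArgs(quantization="modelopt_fp4", enable_flashinfer_trtllm_moe=True)
args.check_trtllm_moe()
assert args.disable_shared_experts_fusion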