Commit 915140fd (unverified)
Authored Aug 04, 2025 by azhurkevich; committed by GitHub on Aug 04, 2025
Parent commit: 36fc9260

[NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552)

Co-authored-by: Cheng Wan <cwan@x.ai>

Showing 8 changed files with 502 additions and 115 deletions (+502, -115)
  python/sglang/srt/layers/moe/ep_moe/layer.py              +18   -7
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +173  -16
  python/sglang/srt/layers/moe/utils.py                     +16   -0
  python/sglang/srt/layers/quantization/modelopt_quant.py   +260  -63
  python/sglang/srt/managers/schedule_batch.py              +1    -1
  python/sglang/srt/models/deepseek_v2.py                   +25   -24
  python/sglang/srt/models/glm4_moe.py                      +2    -4
  python/sglang/srt/server_args.py                          +7    -0
python/sglang/srt/layers/moe/ep_moe/layer.py (+18, -7)

@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (

@@ -48,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul

@@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
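The new dispatch in get_moe_impl_class() keys off two global server args: enable_flashinfer_trtllm_moe and quantization == "modelopt_fp4", and this check runs before the cutlass and expert-parallel branches. A minimal sketch of that selection order, using a plain dict in place of sglang's global_server_args_dict and strings in place of the MoE classes (illustration only, not sglang code):

    # Minimal sketch of the dispatch order introduced above; a plain dict stands
    # in for global_server_args_dict and strings stand in for the MoE classes.
    def pick_moe_impl(args: dict) -> str:
        if args.get("moe_a2a_backend") == "deepep":
            return "DeepEPMoE"
        # FP4 + TRTLLM is checked before the cutlass / expert-parallel branches.
        if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
            return "FlashInferFP4MoE"
        if args.get("enable_flashinfer_cutlass_moe"):
            return "FusedMoE"
        return "EPMoE or FusedMoE, depending on expert parallelism"

    assert pick_moe_impl(
        {"enable_flashinfer_trtllm_moe": True, "quantization": "modelopt_fp4"}
    ) == "FlashInferFP4MoE"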
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+173, -16)

 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

-import importlib.util
+import datetime
+import glob
 import logging
+import os
+import sys
 from enum import Enum
-from functools import lru_cache
 from typing import List, Optional, Tuple

 import torch
-from packaging import version as pkg_version

 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,

@@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,

@@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
-from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_flashinfer_available,
+    is_hip,
+    next_power_of_2,
+)
+
+if is_flashinfer_available():
+    from flashinfer import (
+        RoutingMethodType,
+        fp4_quantize,
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )

 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

+# Try to import FP4 TRTLLM function if flashinfer is available
+trtllm_fp4_block_scale_moe = None
+if should_use_flashinfer_trtllm_moe():
+    try:
+        from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
+    except ImportError:
+        trtllm_fp4_block_scale_moe = None
+
 logger = logging.getLogger(__name__)

-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
+def _is_fp4_quantization_enabled():
+    """Check if ModelOpt FP4 quantization is enabled."""
+    try:
+        # Use the same simple check that works for class selection
+        quantization = global_server_args_dict.get("quantization")
+        return quantization == "modelopt_fp4"
+    except:
+        return False
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim


 class FusedMoeWeightScaleSupported(Enum):

@@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
-            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_cutlass_moe = (
-                    self.enable_flashinfer_cutlass_moe
-                )
         assert self.quant_method is not None

         self.quant_config = quant_config

@@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
+
+
+class FlashInferFP4MoE(FusedMoE):
+    """FP4 TRTLLM MoE implementation using FlashInfer."""
+
+    def __init__(self, *args, **kwargs):
+        # Extract DeepSeek-specific parameters
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+
+        # Extract additional TopK parameters that were previously extracted in forward
+        routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
+
+        super().__init__(*args, **kwargs)
+
+        # Store DeepSeek parameters
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    # ---------------------------------------------------------------------
+    # Helper: quantize hidden states to FP4 each forward pass
+    # ---------------------------------------------------------------------
+    def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
+        """
+        Quantize hidden states using global scale factor from quantization method.
+
+        Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading.
+        Only block scales are computed at runtime for efficiency.
+
+        Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32)
+        """
+        # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8)
+        # Only the block scales are computed at runtime
+        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
+            hidden_states,
+            self.w13_input_scale_quant,
+            16,  # sf_vec_size
+            False,  # use_ue8m0
+            False,  # is_sf_swizzled_layout
+        )
+
+        hs_fp4 = hs_fp4_bytes.reshape(
+            hidden_states.shape[0], hidden_states.shape[1] // 2
+        )
+        hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)
+
+        return hs_fp4, hs_sf
+
+    def forward(self, hidden_states: torch.Tensor, topk_output):
+        """Forward pass using FP4 TRTLLM kernel.
+
+        Args:
+            hidden_states: Input tensor
+            topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
+        """
+        # TRTLLM mode expects (TopK_config, router_logits) tuple
+        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
+            raise ValueError(
+                f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
+            )
+
+        _, router_logits = topk_output
+
+        hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
+
+        router_logits = router_logits.to(torch.float32)
+
+        result = trtllm_fp4_block_scale_moe(
+            routing_logits=router_logits,
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=hs_fp4,
+            hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
+            gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
+            gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
+            gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
+                torch.float8_e4m3fn
+            ),
+            output1_scale_scalar=self.g1_scale_c.data,
+            output1_scale_gate_scalar=self.g1_alphas.data,
+            output2_scale_scalar=self.g2_alphas.data,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.intermediate_size_per_partition,
+            local_expert_offset=self.moe_ep_rank * self.num_local_experts,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_local_experts
+            ),
+            routing_method_type=RoutingMethodType.DeepSeekV3,
+            do_finalize=True,
+        )[0]
+
+        return result
+
+
+def get_fused_moe_impl_class():
+    """Factory function to get the appropriate FusedMoE implementation class."""
+    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
+        # Use FP4 variant when FP4 quantization is enabled
+        return FlashInferFP4MoE
+    elif should_use_flashinfer_trtllm_moe():
+        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
+        return FlashInferFusedMoE
+    else:
+        # Default case
+        return FusedMoE
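_get_tile_tokens_dim sizes the kernel's CTA tile from the average token load per expert and clamps it to the 8-64 range the kernel supports. A standalone sketch of the same arithmetic; the local next_power_of_2 helper below is a stand-in for sglang.srt.utils.next_power_of_2 and is an assumption about its behavior:

    # Standalone sketch of the tile sizing above; next_power_of_2 here is a
    # local stand-in for sglang.srt.utils.next_power_of_2.
    def next_power_of_2(n: int) -> int:
        return 1 if n <= 0 else 1 << (n - 1).bit_length()

    def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        num_tokens_per_expert = (num_tokens * top_k) // num_experts  # assume even routing
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)  # kernel supports 8-64

    # Small decode batches land at the 8-token floor; large batches hit the 64 cap.
    assert tile_tokens_dim(num_tokens=16, top_k=8, num_experts=256) == 8
    assert tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=256) == 64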
python/sglang/srt/layers/moe/utils.py (+16, -0)

+import importlib.util
 from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version

+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result


 class MoeA2ABackend(Enum):
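should_use_flashinfer_trtllm_moe gates the TRTLLM path on the server flag plus a flashinfer version of at least 0.2.9rc1 (or flashinfer being absent, in which case the failure surfaces later at import time). A short sketch of the version comparison it relies on, using packaging directly:

    # Sketch of the version gate; "0.2.9rc1" is the minimum flashinfer version
    # required by the check above.
    from packaging import version as pkg_version

    MIN_FLASHINFER = pkg_version.parse("0.2.9rc1")

    def version_ok(installed: str) -> bool:
        return pkg_version.parse(installed) >= MIN_FLASHINFER

    assert version_ok("0.2.9") is True   # a final release sorts above its own rc
    assert version_ok("0.3.0") is True
    assert version_ok("0.2.8.post1") is False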
python/sglang/srt/layers/quantization/modelopt_quant.py (+260, -63)

 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations

+import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import torch
 from torch.nn.parameter import Parameter

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,

@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2

 if TYPE_CHECKING:

@@ -39,6 +42,11 @@ if is_cuda():
 try:
     from flashinfer import mm_fp4 as fp4_gemm
+    from flashinfer import (
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )

     enable_flashinfer_fp4_gemm = True
 except ImportError:

@@ -47,6 +55,9 @@ except ImportError:
 else:
     fp4_gemm = None
     enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None

 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe

@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(

@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptNvFp4FusedMoEMethod(self)
         return None

@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_cutlass_moe = False
+        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+    @property
+    def enable_flashinfer_cutlass_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutlass_moe setting."""
+        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)

     def create_weights(
         self,

@@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " dynamic quantization is not supported."
             )

+        # TODO(ch-wan): check if this is needed
         layer.num_experts = num_experts
+        layer.num_local_experts = num_experts
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,

@@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # GEMM 2
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // 2,

@@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,

@@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,

@@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

         w2_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)

@@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)

         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)

-    def swizzle_blockscale(self, scale: torch.tensor):
+    def swizzle_blockscale(self, scale: torch.Tensor):
         assert scale.dtype == torch.float8_e4m3fn
         # Pad and blockwise interleave weight_scale
         scale_ndim = scale.ndim

@@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             else swizzled_scale.reshape(B, M, K)
         )

+    def prepare_static_weights_for_kernel(
+        self,
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import (
+            RoutingMethodType,
+            e2m1_and_ufp8sf_scale_to_float,
+            fp4_quantize,
+            next_positive_power_of_2,
+            reorder_rows_for_gated_act_gemm,
+            shuffle_matrix_a,
+            shuffle_matrix_sf_a,
+        )
+
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2
+        )  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 16
+        )  # fp8 scaling factors
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2
+        )  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, hidden_size, intermediate_size // 16
+        )  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+            )
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+            )
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m)
+            )
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+        )
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, hidden_size, intermediate_size // 16)
+        )
+
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1
+        """Process FP4 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        # GEMM 1 scale processing
         if not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):

@@ -880,65 +1020,115 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        if self.enable_flashinfer_cutlass_moe:
-            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
-        else:
-            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
-
-        layer.g1_alphas = Parameter(
-            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
-            requires_grad=False,
-        )
-
-        assert (
-            layer.w13_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
-        )
-        del layer.w13_weight_scale
-
-        # This is for quantization, so we need to invert it.
-        layer.w13_input_scale_quant = Parameter(
-            (1 / w13_input_scale).to(torch.float32), requires_grad=False
-        )
-        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
-
-        # GEMM 2
-        if self.enable_flashinfer_cutlass_moe:
-            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
-        else:
-            w2_input_scale = layer.w2_input_scale
-
-        layer.g2_alphas = Parameter(
-            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
-            requires_grad=False,
-        )
-
-        # This is for quantization, so we need to invert it.
-        layer.w2_input_scale_quant = Parameter(
-            (1 / w2_input_scale).to(torch.float32), requires_grad=False
-        )
-
-        assert (
-            layer.w2_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(
-            w2_blockscale_swizzled, requires_grad=False
-        )
-        del layer.w2_weight_scale
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+        # Calculate input scales based on strategy
+        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
+            w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            w2_input_scale = layer.w2_input_scale
+
+        # Create shared parameters
+        layer.g1_alphas = Parameter(
+            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+        layer.g2_alphas = Parameter(
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+        layer.w13_input_scale_quant = Parameter(
+            (1 / w13_input_scale).to(torch.float32), requires_grad=False
+        )
+        layer.w2_input_scale_quant = Parameter(
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        # Validate weight scales
+        for name, weight_scale in [
+            ("w13", layer.w13_weight_scale),
+            ("w2", layer.w2_weight_scale),
+        ]:
+            assert (
+                weight_scale.shape[2] % 16 == 0
+            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+            assert (
+                weight_scale.dtype == torch.float8_e4m3fn
+            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
+
+        # Weight processing based on strategy
+        if (
+            self.enable_flashinfer_trtllm_moe
+            and reorder_rows_for_gated_act_gemm is not None
+            and shuffle_matrix_sf_a is not None
+        ):
+            # FlashInfer TRTLLM processing - handles both w13 and w2
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = self.prepare_static_weights_for_kernel(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
+
+            # Set flashinfer parameters
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
+
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
+
+            # Clean up weights that won't be used by TRT-LLM
+            del (
+                layer.w2_weight,
+                layer.w2_weight_scale,
+                layer.w13_weight,
+                layer.w13_weight_scale,
+            )
+
+            print("Applied flashinfer weight processing for both w13 and w2")
+        else:
+            # CUTLASS processing - handle w13 and w2 separately
+
+            # Process w13 weights
+            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(
+                w13_blockscale_swizzled, requires_grad=False
+            )
+            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+            # Process w2 weights
+            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(
+                w2_blockscale_swizzled, requires_grad=False
+            )
+            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+            # Both flashinfer cutlass and regular cutlass use same processing for w2
+            print("Applied weight processing for both w13 and w2")

+        # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
         layer.cutlass_moe_params = CutlassMoEParams(
             CutlassMoEType.BlockscaledFP4,

@@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
+
+        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            return layer.forward(x, topk_output)
+
         if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),

@@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

-        topk_weights, topk_ids, _ = topk_output
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
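The TRT-LLM path precomputes per-expert scale products once, during weight loading: g1_alphas = w13_input_scale * w13_weight_scale_2, g2_alphas = w2_input_scale * w2_weight_scale_2, the activation-quantization scales are their reciprocals, and g1_scale_c = w2_input_scale_quant * g1_alphas. A small sketch of that algebra with made-up scalar values (illustrative numbers only, not checkpoint values):

    import torch

    # Illustrative per-tensor global scales; real values come from the checkpoint.
    w13_input_scale = torch.tensor(0.5)
    w13_weight_scale_2 = torch.tensor(2.0)
    w2_input_scale = torch.tensor(0.25)
    w2_weight_scale_2 = torch.tensor(4.0)

    g1_alphas = (w13_input_scale * w13_weight_scale_2).float()   # dequant scale after GEMM1
    g2_alphas = (w2_input_scale * w2_weight_scale_2).float()     # dequant scale after GEMM2
    w13_input_scale_quant = (1 / w13_input_scale).float()        # quantize activations into GEMM1
    w2_input_scale_quant = (1 / w2_input_scale).float()          # quantize activations into GEMM2
    g1_scale_c = (w2_input_scale_quant * g1_alphas).float()      # extra scalar the TRT-LLM kernel takes

    assert torch.isclose(g1_scale_c, torch.tensor(4.0))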
python/sglang/srt/managers/schedule_batch.py (+1, -1)

@@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,

@@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_triton_kernel_moe",
     "enable_multimodal",
     "enable_symm_mem",
+    "quantization",
 ]

 # Put some global args for easy access
python/sglang/srt/models/deepseek_v2.py (+25, -24)

@@ -60,12 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (

@@ -307,8 +304,7 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
+        self.topk = TopK(
                 top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
                 renormalize=config.norm_topk_prob,
                 use_grouped_topk=True,

@@ -318,9 +314,6 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
         )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
-        )

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts

@@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda:
             final_hidden_states *= self.routed_scaling_factor

@@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
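DeepseekV2MoE now passes either a plain (TopK, router_logits) pair (TRTLLM FP4 path, where top-k routing happens inside the fused kernel) or the usual materialized top-k result (CUTLASS/default path) into the experts module. A schematic sketch of how a receiver could tell the two conventions apart (hypothetical helper, not sglang code):

    # Hypothetical helper illustrating the two topk_output conventions used above.
    def describe_topk_output(topk_output) -> str:
        # TRTLLM FP4 path: a plain (TopK module, router_logits) pair; routing is fused in the kernel.
        if isinstance(topk_output, tuple) and len(topk_output) == 2:
            return "flashinfer trtllm fp4 path"
        # CUTLASS / default path: a materialized top-k result (weights, ids, ...).
        return "standard fused moe path"

    assert describe_topk_output(("topk_module", "router_logits")) == "flashinfer trtllm fp4 path"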
python/sglang/srt/models/glm4_moe.py (+2, -4)

@@ -50,11 +50,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
python/sglang/srt/server_args.py (+7, -0)

@@ -481,6 +481,13 @@ class ServerArgs:
             self.tp_size,
         ], "The expert parallel size must be 1 or the same as the tensor parallel size"

+        if self.enable_flashinfer_trtllm_moe:
+            if not self.disable_shared_experts_fusion:
+                self.disable_shared_experts_fusion = True
+                logger.warning(
+                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
+                )
+
         # DeepEP MoE
         if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
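The server-args hook forces shared-experts fusion off whenever the TRTLLM MoE path is enabled. A minimal standalone sketch of that post-init rule, using a toy dataclass whose field names mirror the ServerArgs attributes (illustrative only, not sglang's ServerArgs):

    # Toy reproduction of the validation rule added above.
    import logging
    from dataclasses import dataclass

    logger = logging.getLogger(__name__)

    @dataclass
    class ToyServerArgs:
        enable_flashinfer_trtllm_moe: bool = False
        disable_shared_experts_fusion: bool = False

        def __post_init__(self):
            if self.enable_flashinfer_trtllm_moe and not self.disable_shared_experts_fusion:
                self.disable_shared_experts_fusion = True
                logger.warning(
                    "FlashInfer TRTLLM MoE is enabled. "
                    "--disable-shared-experts-fusion is automatically set."
                )

    args = ToyServerArgs(enable_flashinfer_trtllm_moe=True)
    assert args.disable_shared_experts_fusion is True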