sglang · commit 85486b6f (unverified)
Authored by Kaixi Hou on Jul 27, 2025; committed via GitHub on Jul 27, 2025.
[NVIDIA] Add Flashinfer MoE blockscale fp8 backend (#8036)
Parent: e34cf6ad

Showing 8 changed files with 179 additions and 47 deletions (+179, -47)
python/sglang/srt/layers/moe/ep_moe/layer.py             +102  -7
python/sglang/srt/layers/moe/fused_moe_triton/layer.py     +9  -7
python/sglang/srt/layers/quantization/modelopt_quant.py    +5  -5
python/sglang/srt/managers/schedule_batch.py               +2  -1
python/sglang/srt/models/deepseek_v2.py                   +44  -20
python/sglang/srt/models/qwen2_moe.py                      +2  -2
python/sglang/srt/models/qwen3_moe.py                      +2  -2
python/sglang/srt/server_args.py                          +13  -3
python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -47,12 +47,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )

 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -64,6 +69,13 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight

+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
@@ -140,6 +152,16 @@ class GroupedGemmRunner(torch.nn.Module):
         return c


+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class EPMoE(torch.nn.Module):
     """
     MoE Expert Parallel Impl
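
For a sense of what this heuristic produces, here is a standalone sketch; next_power_of_2 here is a local stand-in for the helper imported from sglang.srt.utils above, and the sample numbers are illustrative:

    # Stand-in for sglang.srt.utils.next_power_of_2.
    def next_power_of_2(n: int) -> int:
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Same logic as _get_tile_tokens_dim above.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

    assert get_tile_tokens_dim(256, 8, 256) == 8    # light load: clamped up to 8
    assert get_tile_tokens_dim(4096, 8, 256) == 64  # heavy load: capped at 64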

@@ -776,14 +798,20 @@ class EPMoE(torch.nn.Module):
             )
             return

+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if use_flashinfer_trtllm_moe:
+            actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+        else:
+            actual_shard_id = shard_id
+
-        if shard_id == "w2":
+        if actual_shard_id == "w2":
             param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
+        elif actual_shard_id == "w1":
             param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
+        elif actual_shard_id == "w3":
             param.data[expert_id][self.intermediate_size :, :] = loaded_weight
         else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")

     def _load_fp8_scale(
         self,

@@ -820,12 +848,18 @@ class EPMoE(torch.nn.Module):
         # Weight scales
         elif "weight_scale" in weight_name:
             if self.use_block_quant:
+                if use_flashinfer_trtllm_moe:
+                    actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+                else:
+                    actual_shard_id = shard_id
+
                 block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
+                if actual_shard_id == "w1":
                     param_data[expert_id][
                         : (self.intermediate_size + block_n - 1) // block_n, :
                     ] = loaded_weight
-                elif shard_id == "w3":
+                elif actual_shard_id == "w3":
                     param_data[expert_id][
                         (self.intermediate_size + block_n - 1) // block_n :, :
                     ] = loaded_weight
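
Both hunks above apply the same w1/w3 swap, so the fused w13 buffer (and its block scales) end up in the "w31" order the TRT-LLM kernel expects. A toy illustration with hypothetical sizes, not sglang API:

    import torch

    intermediate_size, hidden = 4, 8
    w13 = torch.empty(2 * intermediate_size, hidden)
    remap = {"w1": "w3", "w3": "w1", "w2": "w2"}

    def load_shard(shard_id, weight, use_trtllm=True):
        actual = remap[shard_id] if use_trtllm else shard_id
        if actual == "w1":
            w13[:intermediate_size] = weight   # first half of the fused buffer
        elif actual == "w3":
            w13[intermediate_size:] = weight   # second half

    load_shard("w1", torch.ones(intermediate_size, hidden))   # gate proj
    load_shard("w3", torch.zeros(intermediate_size, hidden))  # up proj
    # Result is [w3; w1] ("w31") rather than the usual [w1; w3]:
    assert w13[:intermediate_size].eq(0).all()
    assert w13[intermediate_size:].eq(1).all()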

@@ -1315,12 +1349,73 @@ class DeepEPMoE(EPMoE):
         return down_output


+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_experts_per_partition,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
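
The transpose on the activation scales is easy to get wrong. A shape sketch of the assumed layout (per-token rows of per-group scales, flipped to group-major before the kernel call); the sizes are illustrative, not taken from the commit:

    import torch

    num_tokens, hidden, block_k = 16, 7168, 128
    a_sf = torch.rand(num_tokens, hidden // block_k)  # per-(token, group) scales
    a_sf_t = a_sf.t().contiguous()                    # kernel wants group-major
    assert a_sf_t.shape == (hidden // block_k, num_tokens)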

 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
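
Backend selection now has four branches. A condensed sketch of the precedence implemented by get_moe_impl_class (the dict-based harness is hypothetical):

    def pick_moe_backend(args: dict) -> str:
        # Mirrors get_moe_impl_class: DeepEP first, then FlashInfer CUTLASS,
        # then the new FlashInfer TRT-LLM path, then plain EP, else FusedMoE.
        if args.get("enable_deepep_moe"):
            return "DeepEPMoE"
        if args.get("enable_flashinfer_cutlass_moe"):
            return "FusedMoE"  # the CUTLASS path is served by FusedMoE
        if args.get("enable_flashinfer_trtllm_moe") and args.get("enable_ep_moe"):
            return "FlashInferEPMoE"
        if args.get("enable_ep_moe"):
            return "EPMoE"
        return "FusedMoE"

    assert pick_moe_backend(
        {"enable_flashinfer_trtllm_moe": True, "enable_ep_moe": True}
    ) == "FlashInferEPMoE"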
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -75,7 +75,7 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-        enable_flashinfer_moe: Optional[bool] = False,
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()

@@ -92,16 +92,16 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None

-        if enable_flashinfer_moe and quant_config is None:
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-            enable_flashinfer_moe = False
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False

-        self.enable_flashinfer_moe = enable_flashinfer_moe
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
             assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+                self.enable_flashinfer_cutlass_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1

@@ -141,7 +141,9 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None

         self.quant_config = quant_config
python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_moe = False
+        self.enable_flashinfer_cutlass_moe = False

     def create_weights(
         self,

@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)

@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)

         # GEMM 2
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale

@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
     def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.enable_flashinfer_moe
+        return self.enable_flashinfer_cutlass_moe

     def apply(
         self,

@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
python/sglang/srt/managers/schedule_batch.py
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "enable_flashinfer_moe",
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
python/sglang/srt/models/deepseek_v2.py
@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper

@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )

         self.experts = get_moe_impl_class()(

@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
+                else {}
+            ),
         )
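
The `**( ... if ... else {})` splice above lets the constructor grow backend-specific arguments without branching the whole call. The idiom in isolation (toy names, not sglang API):

    def make(**kwargs):
        return kwargs

    use_trtllm = True
    cfg = make(
        num_experts=8,
        # Extra kwargs are spliced in only when the backend is enabled.
        **(dict(renormalize=True, topk_group=4) if use_trtllm else {}),
    )
    assert cfg == {"num_experts": 8, "renormalize": True, "topk_group": 4}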

@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, topk_output=topk_output
-            )
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)

@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, topk_output=topk_output
-        )
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
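
Both forward paths now build kwargs the same way: when self.topk is None (the TRT-LLM backend routes internally), raw router_logits are passed instead of a precomputed topk_output. A minimal sketch of that convention (harness names are hypothetical):

    def run_experts(experts, topk, hidden_states, router_logits):
        kwargs = {"hidden_states": hidden_states}
        if topk is not None:
            # Standard backends consume a precomputed top-k routing result.
            kwargs["topk_output"] = topk(hidden_states, router_logits)
        else:
            # FlashInferEPMoE does its own routing from the raw logits.
            kwargs["router_logits"] = router_logits
        return experts(**kwargs)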
python/sglang/srt/models/qwen2_moe.py
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
python/sglang/srt/models/qwen3_moe.py
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
python/sglang/srt/server_args.py
@@ -169,7 +169,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0

@@ -428,12 +429,16 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_lm_head. "

         # MoE kernel
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"

+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":

@@ -1293,10 +1298,15 @@ class ServerArgs:
         help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
     )
     parser.add_argument(
-        "--enable-flashinfer-moe",
+        "--enable-flashinfer-cutlass-moe",
         action="store_true",
         help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
     )
+    parser.add_argument(
+        "--enable-flashinfer-trtllm-moe",
+        action="store_true",
+        help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+    )
     parser.add_argument(
         "--enable-flashinfer-allreduce-fusion",
         action="store_true",
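
Putting the new flags together, an illustrative launch command (the model path is a placeholder, not taken from the commit); per the check above, --enable-flashinfer-trtllm-moe requires --enable-ep-moe:

    python -m sglang.launch_server \
        --model-path deepseek-ai/DeepSeek-V3 \
        --enable-ep-moe \
        --enable-flashinfer-trtllm-moe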