Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
03ee4811
Unverified
Commit
03ee4811
authored
Nov 16, 2025
by
amirkl94
Committed by
GitHub
Nov 16, 2025
Browse files
Feature: Support Relu2 in FusedMoE fp8 cutlass path (#27261)
parent
5a87076d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
20 deletions
+42
-20
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+13
-5
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+9
-2
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+20
-13
No files found.
tests/kernels/moe/test_flashinfer.py
View file @
03ee4811
...
...
@@ -77,10 +77,14 @@ class TestData:
@
staticmethod
def
make_moe_tensors_8bit
(
m
:
int
,
k
:
int
,
n
:
int
,
e
:
int
,
reorder
:
bool
m
:
int
,
k
:
int
,
n
:
int
,
e
:
int
,
reorder
:
bool
,
activation
:
str
=
"silu"
)
->
"TestData"
:
is_gated
=
activation
!=
"relu2_no_mul"
hidden_states
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
/
10
w13
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
w13
=
torch
.
randn
(
(
e
,
(
2
*
n
)
if
is_gated
else
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Scale to fp8
...
...
@@ -190,18 +194,22 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu"
,
"relu2_no_mul"
])
def
test_flashinfer_cutlass_moe_fp8_no_graph
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
activation
:
str
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
td
=
TestData
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
reorder
=
False
)
td
=
TestData
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
reorder
=
False
,
activation
=
activation
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
topk_weights
,
topk_ids
,
_
=
FusedMoE
.
select_experts
(
...
...
@@ -233,7 +241,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
activation
=
"silu"
,
activation
=
activation
,
global_num_experts
=
e
,
expert_map
=
None
,
apply_router_weight_on_input
=
True
,
...
...
@@ -253,7 +261,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td
.
layer
,
topk_weights
,
topk_ids
,
activation
=
"silu"
,
activation
=
activation
,
global_num_experts
=
e
,
expert_map
=
None
,
apply_router_weight_on_input
=
True
,
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
03ee4811
...
...
@@ -148,8 +148,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
|
None
,
):
assert
activation
==
"silu"
,
(
"Only activation silu is supported in FlashInferExperts"
from
flashinfer.fused_moe.core
import
ActivationType
activation_str_to_value_map
=
{
"silu"
:
ActivationType
.
Swiglu
,
# This is the default
"relu2_no_mul"
:
ActivationType
.
Relu2
,
}
assert
activation
in
activation_str_to_value_map
,
(
f
"
{
activation
=
}
missing from
{
activation_str_to_value_map
.
keys
()
=
}
"
)
# Select quantization metadata based on FP8 format/path
...
...
@@ -215,6 +221,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
ep_size
=
self
.
ep_size
,
ep_rank
=
self
.
ep_rank
,
output
=
output
,
activation_type
=
activation_str_to_value_map
[
activation
],
# Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale
=
self
.
use_deepseek_fp8_block_scale
,
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
03ee4811
...
...
@@ -354,12 +354,18 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
if
(
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()
and
self
.
moe
.
is_act_and_mul
):
if
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
():
self
.
flashinfer_moe_backend
=
get_flashinfer_moe_backend
()
if
(
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
and
not
self
.
moe
.
is_act_and_mul
):
logger
.
info_once
(
"Non-gated MoE is not supported for min-latency mode,"
"falling back to high-throughput mode"
)
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
CUTLASS
logger
.
info_once
(
f
"Using FlashInfer
{
self
.
flashinfer_moe_backend
.
value
}
kernels"
)
...
...
@@ -557,10 +563,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if
self
.
flashinfer_moe_backend
is
not
None
:
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
register_moe_scaling_factors
(
layer
)
if
self
.
moe
.
is_act_and_mul
:
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
register_moe_scaling_factors
(
layer
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -570,13 +577,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
g1_alphas
=
(
layer
.
w13_weight_scale
*
layer
.
w13_input
_scal
e
)
.
squeeze
(),
g1_alphas
=
layer
.
output1_scales_gate
_scal
ar
.
squeeze
(),
w2_scale
=
layer
.
w2_weight_scale
,
g2_alphas
=
(
layer
.
w2_weight_scale
*
layer
.
w2_input
_scal
e
)
.
squeeze
(),
g2_alphas
=
layer
.
output2_scales
_scal
ar
.
squeeze
(),
a1_scale
=
layer
.
w13_input_scale
,
a1_gscale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_gscale
=
1.0
/
layer
.
w2_input_scale
,
a2_gscale
=
layer
.
w2_input_scale
_inv
,
per_act_token_quant
=
False
,
)
...
...
@@ -642,9 +649,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
)
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
:
assert
not
renormalize
assert
activation
==
"
silu
"
,
(
f
"
Expected 'silu' activation
but got
{
activation
}
"
assert
activation
in
(
"silu"
,
"relu2_no_mul"
),
(
"Expected
activation
to be in ('
silu
'
,
'relu2_no_mul'),"
f
"but got
{
activation
}
"
)
return
flashinfer_cutlass_moe_fp8
(
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment