Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
39d28108
Unverified
Commit
39d28108
authored
Nov 30, 2025
by
Omer Ullman Argov
Committed by
GitHub
Nov 30, 2025
Browse files
[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)
parent
cd719de5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
98 additions
and
22 deletions
+98
-22
tests/kernels/moe/test_flashinfer_moe.py
tests/kernels/moe/test_flashinfer_moe.py
+17
-7
tests/kernels/moe/utils.py
tests/kernels/moe/utils.py
+10
-1
tests/kernels/utils.py
tests/kernels/utils.py
+7
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-3
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+55
-10
No files found.
tests/kernels/moe/test_flashinfer_moe.py
View file @
39d28108
...
...
@@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts
,
is_valid_flashinfer_cutlass_fused_moe
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize
import
(
create_flashinfer_prepare_finalize
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
...
...
@@ -48,9 +48,10 @@ MNK_FACTORS = [
@
pytest
.
mark
.
parametrize
(
"e"
,
[
40
,
64
,
256
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
1
,
6
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu_and_mul"
,
"relu2"
])
@
torch
.
inference_mode
()
def
test_flashinfer_fp4_moe_no_graph
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
activation
:
str
):
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
...
...
@@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
quant_blocksize
=
16
is_gated_act
=
activation
==
"silu_and_mul"
w1_q
,
w2_q
,
quant_config
=
make_test_quant_config
(
e
,
...
...
@@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
quant_dtype
=
"nvfp4"
,
block_shape
=
None
,
per_act_token_quant
=
False
,
make_gate
=
is_gated_act
,
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
...
...
@@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
assert
is_valid_flashinfer_cutlass_fused_moe
(
a
,
w1_q
,
w2_q
)
flashinfer_experts
=
FusedMoEModularKernel
(
MoEP
repare
AndF
inalize
NoEP
(
),
create_flashinfer_p
repare
_f
inalize
(
use_dp
=
False
,
use_nvfp4
=
True
),
FlashInferExperts
(
out_dtype
=
dtype
,
quant_config
=
quant_config
),
)
fi_activation
=
{
"silu_and_mul"
:
"silu"
,
"relu2"
:
"relu2_no_mul"
}[
activation
]
flashinfer_output
=
flashinfer_experts
(
hidden_states
=
a
,
w1
=
w1_q
,
w2
=
w2_q
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
fi_activation
,
)
# Reference check:
...
...
@@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size
=
quant_blocksize
,
)
w1_d
=
torch
.
empty
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
w1_d
=
torch
.
empty
(
(
e
,
(
2
if
is_gated_act
else
1
)
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
w2_d
=
torch
.
empty
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
for
idx
in
range
(
0
,
e
):
...
...
@@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size
=
quant_blocksize
,
)
torch_output
=
torch_moe
(
a_in_dtype
,
w1_d
,
w2_d
,
score
,
topk
)
torch_output
=
torch_moe
(
a_in_dtype
,
w1_d
,
w2_d
,
score
,
topk
,
activation
=
activation
)
torch
.
testing
.
assert_close
(
torch_output
,
flashinfer_output
,
atol
=
1e-1
,
rtol
=
1e-1
...
...
tests/kernels/moe/utils.py
View file @
39d28108
...
...
@@ -264,13 +264,20 @@ def make_test_weights(
quant_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
block_shape
:
list
[
int
]
|
None
=
None
,
per_out_ch_quant
:
bool
=
False
,
make_gate
:
bool
=
True
,
)
->
tuple
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
],
]:
return
(
make_test_weight
(
e
,
2
*
n
,
k
,
in_dtype
,
quant_dtype
,
block_shape
,
per_out_ch_quant
e
,
(
2
if
make_gate
else
1
)
*
n
,
k
,
in_dtype
,
quant_dtype
,
block_shape
,
per_out_ch_quant
,
),
make_test_weight
(
e
,
k
,
n
,
in_dtype
,
quant_dtype
,
block_shape
,
per_out_ch_quant
),
)
...
...
@@ -297,6 +304,7 @@ def make_test_quant_config(
quant_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
list
[
int
]
|
None
=
None
,
make_gate
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
FusedMoEQuantConfig
]:
(
_
,
w1
,
w1_s
,
w1_gs
),
(
_
,
w2
,
w2_s
,
w2_gs
)
=
make_test_weights
(
e
,
...
...
@@ -306,6 +314,7 @@ def make_test_quant_config(
quant_dtype
,
per_out_ch_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
make_gate
=
make_gate
,
)
# Hacky/trivial scales for nvfp4.
...
...
tests/kernels/utils.py
View file @
39d28108
...
...
@@ -14,6 +14,7 @@ from torch._prims_common import TensorLikeType
from
tests.kernels.quant_utils
import
native_w8a8_block_matmul
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.utils.torch_utils
import
make_tensor_with_pad
...
...
@@ -839,6 +840,7 @@ def torch_experts(
per_act_token_quant
=
False
,
block_shape
:
list
[
int
]
|
None
=
None
,
apply_router_weights_on_input
:
bool
=
False
,
activation
:
str
=
"silu_and_mul"
,
)
->
torch
.
Tensor
:
assert
(
global_num_experts
==
-
1
...
...
@@ -881,6 +883,8 @@ def torch_experts(
f32
=
torch
.
float32
act
=
CustomOp
.
op_registry
[
activation
]
for
i
in
range
(
num_experts
):
mask
=
topk_ids
==
i
if
mask
.
sum
():
...
...
@@ -888,7 +892,7 @@ def torch_experts(
tmp1
=
a
[
mask
]
@
w1
[
i
].
transpose
(
0
,
1
)
if
b_bias1
is
not
None
:
tmp1
=
tmp1
+
b_bias1
[
i
].
view
(
1
,
-
1
).
to
(
tmp1
.
dtype
)
tmp2
=
SiluAndMul
()(
tmp1
)
tmp2
=
act
()(
tmp1
)
out
[
mask
]
=
tmp2
@
w2
[
i
].
transpose
(
0
,
1
)
if
b_bias2
is
not
None
:
out
[
mask
]
=
out
[
mask
]
+
b_bias2
[
i
].
view
(
1
,
-
1
).
to
(
tmp1
.
dtype
)
...
...
@@ -969,6 +973,7 @@ def torch_moe(
b_bias2
:
torch
.
Tensor
|
None
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
activation
:
str
=
"silu_and_mul"
,
)
->
torch
.
Tensor
:
score
=
torch
.
softmax
(
score
,
dim
=-
1
,
dtype
=
torch
.
float32
)
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
...
...
@@ -982,6 +987,7 @@ def torch_moe(
b_bias1
,
b_bias2
,
expert_map
,
activation
=
activation
,
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
39d28108
...
...
@@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
# Avoid circular import
from
vllm.model_executor.layers.quantization.modelopt
import
(
ModelOptFp8MoEMethod
,
ModelOptNvFp4FusedMoE
,
)
if
not
isinstance
(
self
.
quant_method
,
(
UnquantizedFusedMoEMethod
,
ModelOptFp8MoEMethod
)
self
.
quant_method
,
(
UnquantizedFusedMoEMethod
,
ModelOptFp8MoEMethod
,
ModelOptNvFp4FusedMoE
,
),
):
raise
NotImplementedError
(
"is_act_and_mul=False is supported only for unquantized "
"
and
ModelOpt FP8
moe for now
"
"
,
ModelOpt FP8
, and ModelOpt NvFp4 checkpoints
"
)
if
not
current_platform
.
is_cuda
():
raise
NotImplementedError
(
...
...
@@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
self
.
_load_combined_w13_weight_scale
(
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
param
=
param
,
param
=
expert_data
,
tp_rank
=
self
.
tp_rank
,
)
return
True
if
return_success
else
None
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
39d28108
...
...
@@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
(
2
if
self
.
moe
.
is_act_and_mul
else
1
)
*
intermediate_size_per_partition
,
# 2 fp4 items are packed in the input dimension
hidden_size
//
2
,
dtype
=
weight_dtype
,
...
...
@@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_weight_scale
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
(
2
if
self
.
moe
.
is_act_and_mul
else
1
)
*
intermediate_size_per_partition
,
# 2 fp4 items are packed in the input dimension
hidden_size
//
self
.
quant_config
.
group_size
,
dtype
=
weight_scale_dtype
,
...
...
@@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
w13_weight_scale_2
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
data
=
torch
.
empty
(
num_experts
,
2
if
self
.
moe
.
is_act_and_mul
else
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight_scale_2"
,
w13_weight_scale_2
)
...
...
@@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_scale_num_experts
=
global_num_experts
if
use_global_sf
else
num_experts
w13_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
global_scale_num_experts
,
2
,
dtype
=
torch
.
float32
),
data
=
torch
.
empty
(
global_scale_num_experts
,
2
if
self
.
moe
.
is_act_and_mul
else
1
,
dtype
=
torch
.
float32
,
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
...
...
@@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
gemm1_weight
=
layer
.
w13_weight
.
data
gemm1_weight_scale
=
layer
.
w13_weight_scale
.
data
if
self
.
allow_flashinfer
and
(
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
if
(
self
.
allow_flashinfer
and
(
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
)
and
self
.
moe
.
is_act_and_mul
):
gemm1_weight
,
gemm1_weight_scale
=
reorder_w1w3_to_w3w1
(
gemm1_weight
,
gemm1_weight_scale
,
dim
=-
2
...
...
@@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer
.
w13_weight_scale
=
Parameter
(
gemm1_weight_scale
,
requires_grad
=
False
)
# Common processing for w13_weight_scale_2
if
not
torch
.
allclose
(
if
self
.
moe
.
is_act_and_mul
and
not
torch
.
allclose
(
layer
.
w13_weight_scale_2
[:,
0
],
layer
.
w13_weight_scale_2
[:,
1
]
):
logger
.
warning_once
(
...
...
@@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w13_blockscale_swizzled
,
requires_grad
=
False
)
w13_weight
=
layer
.
w13_weight
intermediate_size_pad
=
w13_blockscale_swizzled
.
size
(
1
)
-
w13_weight
.
size
(
1
)
if
intermediate_size_pad
:
# padding gated activations will require to split w1 and w3
# and pad them individually
assert
not
self
.
moe
.
is_act_and_mul
,
(
"The intermediate size required padding, "
"but padding is not implemented for gated activations"
)
layer
.
w13_weight
=
Parameter
(
torch
.
nn
.
functional
.
pad
(
w13_weight
,
(
0
,
0
,
0
,
intermediate_size_pad
)
),
requires_grad
=
False
,
)
layer
.
w2_weight
=
Parameter
(
torch
.
nn
.
functional
.
pad
(
layer
.
w2_weight
,
(
0
,
intermediate_size_pad
//
2
,
0
,
0
)
),
requires_grad
=
False
,
)
layer
.
w2_weight_scale
=
Parameter
(
torch
.
nn
.
functional
.
pad
(
layer
.
w2_weight_scale
,
(
0
,
intermediate_size_pad
//
16
)
),
requires_grad
=
False
,
)
w2_blockscale_swizzled
=
swizzle_blockscale
(
layer
.
w2_weight_scale
)
layer
.
w2_weight_scale
=
Parameter
(
w2_blockscale_swizzled
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
.
data
,
requires_grad
=
False
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
...
...
@@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logical_to_physical_map
:
torch
.
Tensor
|
None
=
None
,
logical_replica_count
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
not
self
.
moe
.
is_act_and_mul
:
assert
(
self
.
allow_flashinfer
and
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
CUTLASS
),
(
"Non-gated activations are only supported by the"
" flashinfer CUTLASS backend for modelopt checkpoints"
)
if
(
self
.
allow_flashinfer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment