Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87e07a6b
Unverified
Commit
87e07a6b
authored
Jan 08, 2026
by
Michael Goin
Committed by
GitHub
Jan 08, 2026
Browse files
Revert "feat(moe): Add is_act_and_mul=False support for Triton MoE kernels" (#31978)
parent
75082432
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
9 additions
and
191 deletions
+9
-191
tests/kernels/moe/test_triton_moe_no_act_mul.py
tests/kernels/moe/test_triton_moe_no_act_mul.py
+0
-129
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+0
-9
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+3
-10
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-8
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+2
-23
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+0
-5
No files found.
tests/kernels/moe/test_triton_moe_no_act_mul.py
deleted
100644 → 0
View file @
75082432
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for is_act_and_mul=False MoE using Triton.
This tests the code path used by models like Nemotron-H that use
non-fused activations (e.g., relu2_no_mul) instead of SwiGLU-style
fused activations.
This feature is supported on both CUDA and ROCm (with AITER disabled).
"""
import
pytest
import
torch
from
vllm.platforms
import
current_platform
pytestmark
=
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Tests for is_act_and_mul=False MoE require CUDA or ROCm"
,
)
@
pytest
.
fixture
def
disable_aiter_on_rocm
(
monkeypatch
):
"""Fixture to disable AITER on ROCm to use Triton path."""
if
current_platform
.
is_rocm
():
from
vllm._aiter_ops
import
rocm_aiter_ops
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER_MOE"
,
"0"
)
rocm_aiter_ops
.
refresh_env_variables
()
yield
rocm_aiter_ops
.
refresh_env_variables
()
else
:
# On CUDA, no special setup needed
yield
@
pytest
.
fixture
def
init_workspace
():
"""Initialize workspace manager for MoE tests."""
from
vllm.v1.worker.workspace
import
(
init_workspace_manager
,
reset_workspace_manager
,
)
torch
.
manual_seed
(
42
)
init_workspace_manager
(
torch
.
cuda
.
current_device
())
yield
reset_workspace_manager
()
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
512
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"relu2_no_mul"
,
"silu_no_mul"
,
"gelu_no_mul"
])
@
torch
.
inference_mode
()
def
test_moe_no_act_mul
(
disable_aiter_on_rocm
,
init_workspace
,
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
activation
:
str
,
):
"""Test MoE with is_act_and_mul=False using Triton."""
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
,
fused_topk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
w1
=
torch
.
randn
((
e
,
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
quant_config
=
FusedMoEQuantConfig
.
make
(
is_act_and_mul
=
False
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
True
)
fused_experts
=
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
quant_config
),
)
output
=
fused_experts
(
hidden_states
=
a
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
)
assert
output
.
shape
==
(
m
,
k
),
f
"Expected shape
{
(
m
,
k
)
}
, got
{
output
.
shape
}
"
assert
not
torch
.
isnan
(
output
).
any
(),
"Output contains NaN"
assert
not
torch
.
isinf
(
output
).
any
(),
"Output contains Inf"
assert
output
.
abs
().
sum
()
>
0
,
"Output is all zeros"
@
torch
.
inference_mode
()
def
test_moe_workspace_shapes_no_act_mul
(
disable_aiter_on_rocm
):
"""Test workspace_shapes returns correct sizes for is_act_and_mul=False."""
from
vllm.model_executor.layers.fused_moe
import
TritonExperts
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
M
,
N
,
K
,
topk
=
64
,
256
,
128
,
2
quant_config
=
FusedMoEQuantConfig
.
make
(
is_act_and_mul
=
False
)
experts
=
TritonExperts
(
quant_config
)
ws1
,
ws2
,
out
=
experts
.
workspace_shapes
(
M
,
N
,
K
,
topk
,
8
,
8
,
None
)
assert
ws1
[
2
]
==
max
(
N
,
K
)
assert
out
==
(
M
,
K
)
vllm/model_executor/layers/fused_moe/config.py
View file @
87e07a6b
...
@@ -201,11 +201,6 @@ class FusedMoEQuantConfig:
...
@@ -201,11 +201,6 @@ class FusedMoEQuantConfig:
_w1
:
FusedMoEQuantDesc
_w1
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
# Whether activation is fused with gate multiplication (SwiGLU-style).
# When True: intermediate_size = N // 2 (gate and up are combined)
# When False: intermediate_size = N (no gate multiplication)
is_act_and_mul
:
bool
=
True
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
"illegal quantization"
"illegal quantization"
...
@@ -444,7 +439,6 @@ class FusedMoEQuantConfig:
...
@@ -444,7 +439,6 @@ class FusedMoEQuantConfig:
w1_zp
:
torch
.
Tensor
|
None
=
None
,
w1_zp
:
torch
.
Tensor
|
None
=
None
,
w2_zp
:
torch
.
Tensor
|
None
=
None
,
w2_zp
:
torch
.
Tensor
|
None
=
None
,
weight_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
weight_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
is_act_and_mul
:
bool
=
True
,
)
->
"FusedMoEQuantConfig"
:
)
->
"FusedMoEQuantConfig"
:
"""
"""
General builder function for a FusedMoEQuantConfig.
General builder function for a FusedMoEQuantConfig.
...
@@ -504,7 +498,6 @@ class FusedMoEQuantConfig:
...
@@ -504,7 +498,6 @@ class FusedMoEQuantConfig:
_w2
=
FusedMoEQuantDesc
(
_w2
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w2_scale
,
g2_alphas
,
w2_zp
,
w2_bias
weight_dtype
,
w_shape
,
w2_scale
,
g2_alphas
,
w2_zp
,
w2_bias
),
),
is_act_and_mul
=
is_act_and_mul
,
)
)
assert
quant_config
.
per_act_token_quant
==
per_act_token_quant
assert
quant_config
.
per_act_token_quant
==
per_act_token_quant
assert
quant_config
.
per_out_ch_quant
==
per_out_ch_quant
assert
quant_config
.
per_out_ch_quant
==
per_out_ch_quant
...
@@ -836,7 +829,6 @@ def awq_marlin_moe_quant_config(
...
@@ -836,7 +829,6 @@ def awq_marlin_moe_quant_config(
def
biased_moe_quant_config
(
def
biased_moe_quant_config
(
w1_bias
:
torch
.
Tensor
|
None
,
w1_bias
:
torch
.
Tensor
|
None
,
w2_bias
:
torch
.
Tensor
|
None
,
w2_bias
:
torch
.
Tensor
|
None
,
is_act_and_mul
:
bool
=
True
,
)
->
FusedMoEQuantConfig
:
)
->
FusedMoEQuantConfig
:
"""
"""
Construct a quant config for unquantized activations with biases.
Construct a quant config for unquantized activations with biases.
...
@@ -846,7 +838,6 @@ def biased_moe_quant_config(
...
@@ -846,7 +838,6 @@ def biased_moe_quant_config(
_a2
=
FusedMoEQuantDesc
(),
_a2
=
FusedMoEQuantDesc
(),
_w1
=
FusedMoEQuantDesc
(
bias
=
w1_bias
),
_w1
=
FusedMoEQuantDesc
(
bias
=
w1_bias
),
_w2
=
FusedMoEQuantDesc
(
bias
=
w2_bias
),
_w2
=
FusedMoEQuantDesc
(
bias
=
w2_bias
),
is_act_and_mul
=
is_act_and_mul
,
)
)
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
87e07a6b
...
@@ -871,11 +871,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -871,11 +871,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dp
=
self
.
num_dispatchers
num_dp
=
self
.
num_dispatchers
num_experts
=
local_num_experts
num_experts
=
local_num_experts
max_num_tokens
=
self
.
max_num_tokens
max_num_tokens
=
self
.
max_num_tokens
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
# For non-fused activations: N = intermediate, after act = N
intermediate_size
=
N
//
2
if
self
.
quant_config
.
is_act_and_mul
else
N
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dp
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
intermediate_size
)
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dp
,
(
N
//
2
)
)
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
output
=
(
num_experts
,
max_num_tokens
*
num_dp
,
K
)
return
(
workspace13
,
workspace2
,
output
)
return
(
workspace13
,
workspace2
,
output
)
...
@@ -950,11 +947,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -950,11 +947,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# We can reuse the memory between these because by the time we need
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
# cache3, we're done with cache1
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
E
,
max_num_tokens
,
N
//
2
))
intermediate_size
=
N
//
2
if
self
.
quant_config
.
is_act_and_mul
else
N
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
E
,
max_num_tokens
,
intermediate_size
)
)
# TODO(bnell): should this be done for any quantized type?
# TODO(bnell): should this be done for any quantized type?
if
self
.
quant_config
.
use_fp8_w8a8
:
if
self
.
quant_config
.
use_fp8_w8a8
:
...
@@ -985,7 +978,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -985,7 +978,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO (bnell): use triton utility from batched deep gemm.
# TODO (bnell): use triton utility from batched deep gemm.
self
.
activation
(
self
.
activation
(
activation
,
activation
,
intermediate_cache2
.
view
(
-
1
,
intermediate_size
),
intermediate_cache2
.
view
(
-
1
,
N
//
2
),
intermediate_cache1
.
view
(
-
1
,
N
),
intermediate_cache1
.
view
(
-
1
,
N
),
)
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
87e07a6b
...
@@ -2296,10 +2296,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -2296,10 +2296,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
workspace1
=
(
M
,
topk
,
max
(
N
//
2
,
K
))
# For non-fused activations: N = intermediate, after act = N
intermediate_size
=
N
//
2
if
self
.
quant_config
.
is_act_and_mul
else
N
workspace1
=
(
M
,
topk
,
max
(
intermediate_size
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
workspace2
=
(
M
,
topk
,
max
(
N
,
K
))
output
=
(
M
,
K
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
return
(
workspace1
,
workspace2
,
output
)
...
@@ -2374,10 +2371,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -2374,10 +2371,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note that the output tensor might be in workspace1
# Note that the output tensor might be in workspace1
intermediate_cache1
=
_resize_cache
(
workspace2
,
(
num_tokens
,
top_k_num
,
N
))
intermediate_cache1
=
_resize_cache
(
workspace2
,
(
num_tokens
,
top_k_num
,
N
))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_size
=
N
//
2
if
self
.
quant_config
.
is_act_and_mul
else
N
intermediate_cache2
=
_resize_cache
(
intermediate_cache2
=
_resize_cache
(
workspace13
,
(
num_tokens
*
top_k_num
,
intermediate_size
)
workspace13
,
(
num_tokens
*
top_k_num
,
N
//
2
)
)
)
intermediate_cache3
=
_resize_cache
(
workspace2
,
(
num_tokens
,
top_k_num
,
K
))
intermediate_cache3
=
_resize_cache
(
workspace2
,
(
num_tokens
,
top_k_num
,
K
))
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
87e07a6b
...
@@ -600,15 +600,9 @@ class FusedMoE(CustomOp):
...
@@ -600,15 +600,9 @@ class FusedMoE(CustomOp):
"is_act_and_mul=False is supported only for unquantized "
"is_act_and_mul=False is supported only for unquantized "
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
)
# ROCm without AITER MoE uses Triton which supports
if
not
current_platform
.
is_cuda
():
# is_act_and_mul=False via standard PyTorch ops (F.silu, F.gelu)
rocm_without_aiter_moe
=
(
current_platform
.
is_rocm
()
and
not
rocm_aiter_ops
.
is_fused_moe_enabled
()
)
if
not
current_platform
.
is_cuda
()
and
not
rocm_without_aiter_moe
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"is_act_and_mul=False is supported only for CUDA, or ROCm "
"is_act_and_mul=False is supported only for CUDA for now"
"(when AITER MoE is disabled) for now"
)
)
if
self
.
enable_eplb
and
not
self
.
quant_method
.
supports_eplb
:
if
self
.
enable_eplb
and
not
self
.
quant_method
.
supports_eplb
:
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
87e07a6b
...
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
...
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
math
import
prod
,
sqrt
from
math
import
prod
from
typing
import
final
from
typing
import
final
import
torch
import
torch
...
@@ -575,35 +575,14 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -575,35 +575,14 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def
activation
(
def
activation
(
self
,
activation
:
str
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
self
,
activation
:
str
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
)
->
None
:
)
->
None
:
# Fused activations (SwiGLU-style): output is half the size of input
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
if
activation
==
"silu"
:
if
activation
==
"silu"
:
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
torch
.
ops
.
_C
.
silu_and_mul
(
output
,
input
)
torch
.
ops
.
_C
.
silu_and_mul
(
output
,
input
)
elif
activation
==
"gelu"
:
elif
activation
==
"gelu"
:
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
torch
.
ops
.
_C
.
gelu_and_mul
(
output
,
input
)
torch
.
ops
.
_C
.
gelu_and_mul
(
output
,
input
)
elif
activation
==
"swigluoai"
:
elif
activation
==
"swigluoai"
:
# alpha = 1.702, limit = 7.0
# alpha = 1.702, limit = 7.0
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
torch
.
ops
.
_C
.
swigluoai_and_mul
(
output
,
input
)
torch
.
ops
.
_C
.
swigluoai_and_mul
(
output
,
input
)
# Non-fused activations (is_act_and_mul=False): output same size as input
elif
activation
==
"silu_no_mul"
:
assert
output
.
size
(
-
1
)
==
input
.
size
(
-
1
)
# Use out= argument to avoid intermediate tensor
torch
.
sigmoid
(
input
,
out
=
output
)
output
.
mul_
(
input
)
elif
activation
==
"gelu_no_mul"
:
assert
output
.
size
(
-
1
)
==
input
.
size
(
-
1
)
# GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
# Use out= and in-place ops to avoid intermediate tensors
output
.
copy_
(
input
).
div_
(
sqrt
(
2
))
torch
.
erf
(
output
,
out
=
output
)
output
.
add_
(
1
).
mul_
(
input
).
mul_
(
0.5
)
elif
activation
==
"relu2_no_mul"
:
assert
output
.
size
(
-
1
)
==
input
.
size
(
-
1
)
# ReLU²: clamp has out=, then in-place square
torch
.
clamp
(
input
,
min
=
0
,
out
=
output
)
output
.
square_
()
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
87e07a6b
...
@@ -299,12 +299,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -299,12 +299,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
return
biased_moe_quant_config
(
return
biased_moe_quant_config
(
layer
.
w13_bias
,
layer
.
w13_bias
,
layer
.
w2_bias
,
layer
.
w2_bias
,
is_act_and_mul
=
self
.
moe
.
is_act_and_mul
,
)
)
elif
not
self
.
moe
.
is_act_and_mul
:
# Create a config with is_act_and_mul=False since
# FUSED_MOE_UNQUANTIZED_CONFIG has is_act_and_mul=True
return
FusedMoEQuantConfig
.
make
(
is_act_and_mul
=
False
)
else
:
else
:
return
FUSED_MOE_UNQUANTIZED_CONFIG
return
FUSED_MOE_UNQUANTIZED_CONFIG
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment