Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3fd1fb0b
Unverified
Commit
3fd1fb0b
authored
Nov 28, 2025
by
Huamin Li
Committed by
GitHub
Nov 28, 2025
Browse files
Revert "[LoRA] Support FusedMoE LoRA Triton kernel for mxfp4 (#28971)" (#29697)
Signed-off-by:
Huamin Li
<
3ericli@gmail.com
>
parent
a51f4186
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
440 deletions
+11
-440
tests/kernels/moe/test_modular_oai_triton_moe.py
tests/kernels/moe/test_modular_oai_triton_moe.py
+0
-250
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+9
-26
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+0
-146
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+2
-18
No files found.
tests/kernels/moe/test_modular_oai_triton_moe.py
deleted
100644 → 0
View file @
a51f4186
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test modular OAI Triton MoE
"""
import
pytest
import
torch
from
vllm.utils.import_utils
import
has_triton_kernels
if
not
has_triton_kernels
():
pytest
.
skip
(
"triton_kernels not found, skipping all related tests"
,
allow_module_level
=
True
,
)
from
triton_kernels.matmul_ogs
import
FlexCtx
,
PrecisionConfig
from
triton_kernels.numerics
import
InFlexData
from
triton_kernels.numerics_details.mxfp
import
downcast_to_mxfp
,
upcast_from_mxfp
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.testing
import
assert_close
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
OAITritonExperts
,
UnfusedOAITritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
FusedMoEModularKernel
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.utils
import
shuffle_weight
from
vllm.platforms
import
current_platform
# Problem sizes fed to the parametrized test as (m, n, k) triples.
MNK = [
    (1, 512, 384),
    (1, 2880, 2880),
    (2, 512, 384),
    (2, 2880, 2880),
    (32, 2880, 2880),
    (64, 2880, 2880),
]
def unshuffle_weight(w: torch.Tensor):
    """De-interleave the last dimension of ``w``.

    The even-indexed entries of the last dimension are placed first,
    followed by the odd-indexed entries; all leading dimensions are left
    untouched.

    NOTE(review): presumably the inverse of ``shuffle_weight`` -- confirm
    against that helper.
    """
    even_cols, odd_cols = (w[..., start::2] for start in (0, 1))
    return torch.concat((even_cols, odd_cols), dim=-1)
def make_weights(dtype, k, n, e):
    """Create paired reference / Triton-side MoE weights for the test.

    Random weights and biases are drawn on the CUDA device for ``e``
    experts.  The Triton-side copies are quantized with ``downcast_to_mxfp``
    and converted to the default mxfp4 matmul layouts, while the reference
    weights are replaced by the de-quantized (``upcast_from_mxfp``) values
    so that both paths see the same quantization error.

    Returns a 10-tuple:
    ``(w1, w2, w1_bias, w2_bias, w1_tri, w2_tri, w1_bias_tri, w2_bias_tri,
    w1_precision_config, w2_precision_config)``.
    """
    device = "cuda"

    # The draw order (w1, w1_bias, w2, w2_bias) is kept fixed so that a
    # seeded run always produces the same tensors.
    w1 = torch.randn((e, k, 2 * n), dtype=dtype, device=device)
    w1_bias = torch.randn((e, 2 * n), dtype=dtype, device=device)
    w2 = torch.randn((e, n, k), dtype=dtype, device=device)
    w2_bias = torch.randn((e, k), dtype=dtype, device=device)

    # Triton-side biases are float32 copies; only the w1 bias is shuffled.
    w1_bias_tri = shuffle_weight(w1_bias.clone().to(torch.float32))
    w2_bias_tri = w2_bias.clone().to(torch.float32)

    # Quantize the (shuffled) w1 copy, then rebuild the reference w1 from
    # the quantized values and undo the shuffle.
    w1_tri, w1_scale_tri = downcast_to_mxfp(
        shuffle_weight(w1.clone()), torch.uint8, axis=1
    )
    w1 = unshuffle_weight(upcast_from_mxfp(w1_tri, w1_scale_tri, dtype, axis=1))

    # w2 goes through the same quantize / de-quantize round trip, unshuffled.
    w2_tri, w2_scale_tri = downcast_to_mxfp(w2.clone(), torch.uint8, axis=1)
    w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, dtype, axis=1)

    num_warps = 8
    w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w_scale_layout, w_scale_layout_opts = (
        layout.make_default_matmul_mxfp4_w_scale_layout(
            mx_axis=1, num_warps=num_warps
        )
    )

    def _as_fp4_layout(values):
        # Wrap the packed values as FP4 and move them to the weight layout.
        return convert_layout(
            wrap_torch_tensor(values, FP4), w_layout, **w_layout_opts
        )

    def _as_scale_layout(scales):
        # Move the scales to the matching scale layout.
        return convert_layout(
            wrap_torch_tensor(scales), w_scale_layout, **w_scale_layout_opts
        )

    def _precision_for(scales):
        return PrecisionConfig(
            weight_scale=scales, flex_ctx=FlexCtx(rhs_data=InFlexData())
        )

    w1_tri = _as_fp4_layout(w1_tri)
    w1_scale_tri = _as_scale_layout(w1_scale_tri)
    w2_tri = _as_fp4_layout(w2_tri)
    w2_scale_tri = _as_scale_layout(w2_scale_tri)

    w1_precision_config = _precision_for(w1_scale_tri)
    w2_precision_config = _precision_for(w2_scale_tri)

    return (
        w1,
        w2,
        w1_bias,
        w2_bias,
        w1_tri,
        w2_tri,
        w1_bias_tri,
        w2_bias_tri,
        w1_precision_config,
        w2_precision_config,
    )
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
    """Reference SwiGLU-style activation used by the torch MoE baseline.

    ``x`` is split in two halves along the last dimension: a gate half and
    a linear half.  When ``limit`` is not ``None`` the gate half is clamped
    from above at ``limit`` and the linear half is clamped to
    ``[-limit, limit]``.  The result is
    ``gate * sigmoid(alpha * gate) * (linear + 1)``, so the last dimension
    of the output is half that of ``x``.
    """
    gate, linear = torch.chunk(x, 2, dim=-1)

    if limit is not None:
        gate = gate.clamp(max=limit)
        linear = linear.clamp(min=-limit, max=limit)

    gated = gate * torch.sigmoid(alpha * gate)
    # Note the extra bias of 1 added to the linear half.
    return gated * (linear + 1)
def torch_moe_impl(
    hidden_states: torch.Tensor,  # (M, K)
    w1: torch.Tensor,  # (E, K, 2N)
    w2: torch.Tensor,  # (E, N, K)
    w1_bias: torch.Tensor,  # (E, 2N)
    w2_bias: torch.Tensor,  # (E, K)
    topk_weights: torch.Tensor,  # (M, topk)
    topk_ids: torch.Tensor,  # (M, topk)
):
    """Plain-PyTorch MoE forward pass used as the numerical reference.

    For every token the weights and biases of its selected experts are
    gathered with ``topk_ids``; the first projection is followed by
    ``swiglu`` (with ``limit=7``), then the second projection, and finally
    the per-expert outputs are combined using ``topk_weights``.
    """
    # First projection with the gathered per-token expert parameters.
    gathered_w1 = w1[topk_ids, ...]
    gathered_b1 = w1_bias[topk_ids, ...]
    projected = torch.einsum("bekc,bk->bec", gathered_w1, hidden_states) + gathered_b1

    activated = swiglu(projected, limit=7)

    # Second projection, again with per-token expert parameters.
    gathered_w2 = w2[topk_ids, ...]
    gathered_b2 = w2_bias[topk_ids, ...]
    per_expert = torch.einsum("bekc,bek->bec", gathered_w2, activated) + gathered_b2

    # Weighted sum of experts
    return torch.einsum("bec,be->bc", per_expert, topk_weights)
def oai_triton_moe_impl(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: "PrecisionConfig",
    w2_scale: "PrecisionConfig",
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
    num_experts: int,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    unfused: bool = False,
) -> torch.Tensor:
    """Run the OAI Triton MoE experts through a modular kernel.

    An mxfp4 w4a16 quant config is built from the given biases and
    precision configs.  ``unfused`` selects ``UnfusedOAITritonExperts``
    instead of ``OAITritonExperts``.  The experts are paired with
    ``MoEPrepareAndFinalizeNoEP`` in a ``FusedMoEModularKernel`` whose
    ``forward`` result is returned.

    Note that ``forward`` is invoked with ``inplace=True``.
    """
    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )

    experts_cls = UnfusedOAITritonExperts if unfused else OAITritonExperts
    fused_experts = experts_cls(quant_config)

    modular_kernel = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)

    return modular_kernel.forward(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation="swigluoai",
        global_num_experts=num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
    )
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("m,n,k", MNK)
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [4])
@pytest.mark.parametrize("unfused", [True, False])
def test_oai_triton_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    unfused: bool,
):
    """Compare the OAI Triton MoE kernels against the torch reference.

    Weights, activations and router logits are generated from a fixed seed;
    the top-``topk`` logits are soft-maxed into routing weights, and the
    kernel output must match the reference within the given tolerances.
    """
    current_platform.seed_everything(0)

    weights = make_weights(dtype, k, n, num_experts)
    w1, w2, w1_bias, w2_bias = weights[:4]
    w1_tri, w2_tri, w1_bias_tri, w2_bias_tri = weights[4:8]
    w1_precision_config, w2_precision_config = weights[8:]

    x = torch.randn((m, k), dtype=dtype, device="cuda")
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=dtype)

    # Routing: pick the top-k logits per token and normalise them.
    topk_logits, topk_ids = torch.topk(router_logits, k=topk, dim=-1, sorted=True)
    topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1)

    with set_current_vllm_config(VllmConfig()):
        # The reference is computed first: the kernel path below forwards
        # with inplace=True (presumably it may overwrite ``x`` -- confirm).
        out_ref = torch_moe_impl(
            hidden_states=x,
            w1=w1,
            w2=w2,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        out = oai_triton_moe_impl(
            x=x,
            w1=w1_tri,
            w2=w2_tri,
            w1_scale=w1_precision_config,
            w2_scale=w2_precision_config,
            w1_bias=w1_bias_tri,
            w2_bias=w2_bias_tri,
            num_experts=num_experts,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            unfused=unfused,
        )

    assert_close(ref=out_ref, tri=out, maxtol=0.025, rmstol=0.005)
vllm/lora/layers/fused_moe.py
View file @
3fd1fb0b
...
...
@@ -20,24 +20,15 @@ from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str
,
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
modular_marlin_fused_moe
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
TritonExperts
,
modular_triton_fused_moe
,
try_get_optimal_moe_config
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe_modular_method
import
(
FusedMoEModularMethod
,
)
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
UnfusedOAITritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
.utils
import
_get_lora_device
...
...
@@ -123,23 +114,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
self
.
base_layer
.
ensure_moe_quant_config_init
()
quant_config
=
self
.
base_layer
.
quant_method
.
moe_quant_config
prepare_finalize
=
MoEPrepareAndFinalizeNoEP
()
m_fused_moe_fn
=
FusedMoEModularKernel
(
prepare_finalize
,
self
.
base_layer
.
quant_method
.
select_gemm_impl
(
prepare_finalize
,
self
.
base_layer
),
self
.
base_layer
.
shared_experts
,
getattr
(
self
.
base_layer
,
"shared_experts_stream"
,
None
),
)
if
quant_config
.
use_mxfp4_w4a16
:
assert
isinstance
(
m_fused_moe_fn
.
fused_experts
,
(
MarlinExperts
,
UnfusedOAITritonExperts
)
m_fused_moe_fn
=
(
modular_triton_fused_moe
(
quant_config
,
shared_experts
=
self
.
base_layer
.
shared_experts
)
else
:
assert
isinstanc
e
(
m_fused_moe_fn
.
fused_experts
,
(
MarlinExperts
,
TritonE
xperts
)
if
not
quant_config
.
use_mxfp4_w4a16
else
modular_marlin_fused_mo
e
(
quant_config
,
shared_experts
=
self
.
base_layer
.
shared_e
xperts
)
)
def
fwd_decorator
(
layer
,
func
):
def
wrapper
(
*
args
,
**
kwargs
):
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
3fd1fb0b
...
...
@@ -5,7 +5,6 @@
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FUSED_MOE_UNQUANTIZED_CONFIG
,
...
...
@@ -377,148 +376,3 @@ class OAITritonExperts(BaseOAITritonExperts):
intermediate_cache
=
workspace2
,
a1q_scale
=
a1q_scale
,
)
class UnfusedOAITritonExperts(BaseOAITritonExperts):
    """
    Triton based MoE experts working on the standard activation format that
    run the activation and the final reduction (moe_sum) as separate steps
    instead of fusing them into the matmul_ogs kernel.

    Keeping these two steps outside the kernel exposes injection points for
    them; injecting LoRA modules on the activation and moe_sum is one use
    case.
    """

    def __init__(self, quant_config: FusedMoEQuantConfig):
        """Accept only mxfp4 w4a16 quantization configs."""
        # TODO (varun) : Enable activation quantization
        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
        super().__init__(quant_config)

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        """Return the pair of activation formats; both are ``Standard``."""
        standard = mk.FusedMoEActivationFormat.Standard
        return (standard, standard)

    def supports_chunking(self) -> bool:
        """Chunking is supported."""
        return True

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the shapes ``(workspace1, workspace2, output)``.

        Both workspaces have ``M * topk`` rows; the first has ``N // 2``
        columns, the second ``max(N, K)``.  The output is ``(M, K)``.
        """
        routed_rows = M * topk
        workspace1 = (routed_rows, N // 2)
        workspace2 = (routed_rows, max(N, K))
        output = (M, K)
        return (workspace1, workspace2, output)

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor):
        """Reduce ``input`` into ``output`` by delegating to ``ops.moe_sum``."""
        ops.moe_sum(input, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run both expert matmuls with activation and moe_sum kept unfused.

        Steps:
          1. ``matmul_ogs`` with ``w1`` and no fused activation, written to
             a view of ``workspace2``.
          2. ``self.activation`` writes its result to a view of
             ``workspace13``.
          3. ``matmul_ogs`` with ``w2`` after ``routing_data.n_expts_act``
             has been set to 1, written to another view of ``workspace2``.
          4. ``self.moe_sum`` reduces the per-expert results into ``output``.

        The router weights (``gammas``) are applied in the first matmul
        when ``apply_router_weight_on_input`` is set, otherwise in the
        second one.
        """
        if self.quant_config is None:
            self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

        if expert_map is not None:
            topk_ids = expert_map[topk_ids]

        local_num_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = local_num_experts

        routing_data, gather_indx, scatter_indx = self._make_routing_data(
            topk_ids, topk_weights, local_num_experts
        )

        topk = topk_ids.size(1)

        # type check, uint8 means mxfp4
        assert hidden_states.dtype == torch.bfloat16
        assert (
            self.quant_config.w1_bias is None
            or self.quant_config.w1_bias.dtype == torch.float32
        )
        assert (
            self.quant_config.w2_bias is None
            or self.quant_config.w2_bias.dtype == torch.float32
        )

        # Shape check, only check non-mxfp4
        assert hidden_states.ndim == 2
        assert hidden_states.shape[-1] == w1.shape[-2]
        assert w2.shape[-1] == w1.shape[1]

        batch_dim = 1
        M, K = hidden_states.shape
        E, _, N = w1.shape

        if global_num_experts == -1:
            global_num_experts = E

        # Note that the output tensor might be in workspace13.
        # The first and third caches are both carved out of workspace2.
        routed_rows = M * topk
        intermediate_cache1 = _resize_cache(workspace2, (batch_dim, routed_rows, N))
        intermediate_cache3 = _resize_cache(workspace2, (batch_dim, routed_rows, K))
        intermediate_cache2 = _resize_cache(workspace13, (routed_rows, N // 2))

        gammas = routing_data.gate_scal if routing_data else None
        # Router weights go into exactly one of the two matmuls.
        if apply_router_weight_on_input:
            first_gammas, second_gammas = gammas, None
        else:
            first_gammas, second_gammas = None, gammas

        matmul_ogs(
            hidden_states,
            w1,
            self.quant_config.w1_bias,
            routing_data,
            gather_indx=gather_indx,
            precision_config=self.quant_config.w1_precision,
            gammas=first_gammas,
            fused_activation=None,
            y=intermediate_cache1,
        )

        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # matmul_ogs grouped reduction fuse sum across multiple experts:
        # y[dst_ind // n_expts_act, :] += x[src_ind, :]
        # Need to set n_expts_act to 1 to unfuse moe_sum
        routing_data.n_expts_act = 1

        matmul_ogs(
            intermediate_cache2,
            w2,
            self.quant_config.w2_bias,
            routing_data,
            scatter_indx=scatter_indx,
            precision_config=self.quant_config.w2_precision,
            gammas=second_gammas,
            y=intermediate_cache3,
        )

        self.moe_sum(intermediate_cache3.view(-1, topk, K), output)
vllm/model_executor/layers/quantization/mxfp4.py
View file @
3fd1fb0b
...
...
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
)
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
OAITritonExperts
,
UnfusedOAITritonExperts
,
)
from
vllm.model_executor.layers.fused_moe.trtllm_moe
import
TrtLlmGenExperts
from
vllm.model_executor.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
...
...
@@ -84,21 +83,8 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
if
not
current_platform
.
is_cuda
():
return
Mxfp4Backend
.
NONE
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported
=
(
has_triton_kernels
()
and
is_torch_equal_or_newer
(
"2.8.0"
)
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and
(
9
,
0
)
<=
current_platform
.
get_device_capability
()
<
(
11
,
0
)
)
if
envs
.
VLLM_MXFP4_USE_MARLIN
or
not
triton_kernels_supported
:
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using Marlin backend"
)
return
Mxfp4Backend
.
MARLIN
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using Triton backend"
)
return
Mxfp4Backend
.
TRITON
logger
.
info_once
(
"[get_mxfp4_backend_with_lora] Using Marlin backend"
)
return
Mxfp4Backend
.
MARLIN
def
get_mxfp4_backend
(
with_lora_support
:
bool
)
->
Mxfp4Backend
:
...
...
@@ -868,8 +854,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
MARLIN
:
return
MarlinExperts
(
self
.
moe_quant_config
)
elif
self
.
mxfp4_backend
==
Mxfp4Backend
.
TRITON
:
if
self
.
moe
.
is_lora_enabled
:
return
UnfusedOAITritonExperts
(
self
.
moe_quant_config
)
return
OAITritonExperts
(
self
.
moe_quant_config
)
else
:
raise
NotImplementedError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment