Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7b926e89
Unverified
Commit
7b926e89
authored
Dec 22, 2025
by
Yongye Zhu
Committed by
GitHub
Dec 22, 2025
Browse files
[MoE Refactor][9/N] Use modular kernel for unquantized Triton MoE (#31052)
Signed-off-by:
Yongye Zhu
<
zyy1102000@gmail.com
>
parent
ab3a85fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
7 deletions
+22
-7
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+7
-0
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+15
-7
No files found.
tests/kernels/moe/test_moe.py
View file @
7b926e89
...
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
...
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_w
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.v1.worker.workspace
import
init_workspace_manager
NUM_EXPERTS
=
[
8
,
64
,
192
]
NUM_EXPERTS
=
[
8
,
64
,
192
]
EP_SIZE
=
[
1
,
4
]
EP_SIZE
=
[
1
,
4
]
...
@@ -487,6 +488,7 @@ def test_mixtral_moe(
...
@@ -487,6 +488,7 @@ def test_mixtral_moe(
monkeypatch
.
setenv
(
"MASTER_ADDR"
,
"localhost"
)
monkeypatch
.
setenv
(
"MASTER_ADDR"
,
"localhost"
)
monkeypatch
.
setenv
(
"MASTER_PORT"
,
"12345"
)
monkeypatch
.
setenv
(
"MASTER_PORT"
,
"12345"
)
init_distributed_environment
()
init_distributed_environment
()
init_workspace_manager
(
torch
.
cuda
.
current_device
())
# Instantiate our and huggingface's MoE blocks
# Instantiate our and huggingface's MoE blocks
vllm_config
.
compilation_config
.
static_forward_context
=
dict
()
vllm_config
.
compilation_config
.
static_forward_context
=
dict
()
...
@@ -533,6 +535,11 @@ def test_mixtral_moe(
...
@@ -533,6 +535,11 @@ def test_mixtral_moe(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
# FIXME (zyongye) fix this after we move self.kernel
# assignment in FusedMoE.__init__
vllm_moe
.
experts
.
quant_method
.
process_weights_after_loading
(
vllm_moe
.
experts
)
# Run forward passes for both MoE blocks
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
7b926e89
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
...
@@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -23,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoEPrepareAndFinalize
,
)
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
...
@@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
...
@@ -30,9 +34,9 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_batched_moe
import
BatchedTritonExperts
from
.fused_moe
import
TritonExperts
,
fused_experts
from
.fused_moe
import
TritonExperts
else
:
else
:
fused_experts
=
None
# type: ignore
TritonExperts
=
None
# type: ignore
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
from
.moe_pallas
import
fused_moe
as
fused_moe_pallas
...
@@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -265,6 +269,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
layer
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
else
:
else
:
layer
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
layer
.
cpu_fused_moe
=
cpu_fused_moe
.
CPUFusedMOE
(
layer
)
elif
current_platform
.
is_cuda_alike
():
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
self
.
moe_quant_config
),
shared_experts
=
None
,
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -278,9 +289,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
router_logits
=
router_logits
,
router_logits
=
router_logits
,
)
)
def
get_fused_moe_quant_config
(
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
:
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
if
self
.
moe
.
has_bias
:
if
self
.
moe
.
has_bias
:
return
biased_moe_quant_config
(
return
biased_moe_quant_config
(
layer
.
w13_bias
,
layer
.
w13_bias
,
...
@@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -322,7 +331,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
)
)
else
:
else
:
result
=
fused_experts
(
result
=
self
.
kernel
(
hidden_states
=
x
,
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
...
@@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -330,7 +339,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
layer
.
activation
,
activation
=
layer
.
activation
,
quant_config
=
self
.
moe_quant_config
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
expert_map
=
layer
.
expert_map
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment