Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7cf56a59
Unverified
Commit
7cf56a59
authored
Apr 01, 2026
by
bnellnm
Committed by
GitHub
Apr 01, 2026
Browse files
[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
5e30e9b9
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
41 additions
and
68 deletions
+41
-68
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+8
-37
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+4
-4
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+1
-1
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+10
-10
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-2
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+1
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+4
-4
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+1
-1
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+2
-2
vllm/model_executor/layers/quantization/quark/quark_moe.py
vllm/model_executor/layers/quantization/quark/quark_moe.py
+3
-3
vllm/model_executor/models/transformers/moe.py
vllm/model_executor/models/transformers/moe.py
+2
-0
No files found.
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
7cf56a59
...
...
@@ -3,14 +3,10 @@
import
torch
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
# TODO(bnell):
Add shared + fused combo function? e.g. +
# TODO(bnell):
Remove this entirely
class
SharedFusedMoE
(
FusedMoE
):
"""
A FusedMoE operation that also computes the results of shared experts.
...
...
@@ -23,36 +19,11 @@ class SharedFusedMoE(FusedMoE):
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
self
.
use_overlapped
:
if
self
.
_shared_experts
is
not
None
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if
(
self
.
reduce_results
and
get_tensor_model_parallel_world_size
()
>
1
and
self
.
must_reduce_shared_expert_outputs
()
):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
else
:
shared_out
=
None
fused_out
=
super
().
forward
(
result
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
if
self
.
shared_experts
is
None
:
return
None
,
result
else
:
shared_out
,
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
# ensure early TP reduction of shared expert outputs when required
if
(
shared_out
is
not
None
and
self
.
reduce_results
and
get_tensor_model_parallel_world_size
()
>
1
and
self
.
must_reduce_shared_expert_outputs
()
):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
return
shared_out
,
fused_out
return
result
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
7cf56a59
...
...
@@ -245,7 +245,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
return
self
.
forward
(
layer
=
layer
,
x
=
x
,
...
...
@@ -261,7 +261,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
hidden_states
=
x
,
...
...
@@ -283,7 +283,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
return
self
.
forward_native
(
layer
,
x
,
topk_weights
,
topk_ids
,
shared_experts_input
)
...
...
@@ -293,7 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
if
self
.
unquantized_backend
==
UnquantizedMoeBackend
.
CPU
:
assert
self
.
moe_kernel
is
None
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
7cf56a59
...
...
@@ -811,7 +811,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
return
fused_marlin_moe
(
x
,
layer
.
w13_qweight
,
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
7cf56a59
...
...
@@ -483,7 +483,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
# TODO(bnell): Do these need to be called on the hot path?
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
7cf56a59
...
...
@@ -355,7 +355,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
x
,
...
...
@@ -603,7 +603,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
...
...
@@ -628,7 +628,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
x
,
...
...
@@ -963,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
x
,
...
...
@@ -987,7 +987,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
...
...
@@ -1127,7 +1127,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
...
...
@@ -1611,7 +1611,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
kernel_backend
==
"Flashinfer"
return
flashinfer_trtllm_mxint4_moe
(
x
=
x
,
...
...
@@ -1638,7 +1638,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
kernel_backend
==
"Marlin"
return
fused_marlin_moe
(
x
,
...
...
@@ -1887,7 +1887,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
...
...
@@ -2502,7 +2502,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
if
layer
.
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet."
...
...
vllm/model_executor/layers/quantization/experts_int8.py
View file @
7cf56a59
...
...
@@ -141,7 +141,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
7cf56a59
...
...
@@ -877,7 +877,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
...
...
@@ -902,7 +902,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
7cf56a59
...
...
@@ -650,7 +650,7 @@ class GGUFMoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
if
layer
.
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
7cf56a59
...
...
@@ -907,7 +907,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
return
fused_marlin_moe
(
x
,
layer
.
w13_qweight
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
7cf56a59
...
...
@@ -935,7 +935,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
...
...
@@ -960,7 +960,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
...
...
@@ -1419,7 +1419,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
...
...
@@ -1444,7 +1444,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
7cf56a59
...
...
@@ -369,7 +369,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
layer
.
activation
==
MoEActivation
.
SILU
,
(
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
7cf56a59
...
...
@@ -377,7 +377,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
not
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
...
...
@@ -398,7 +398,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
is_monolithic
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply_monolithic
(
...
...
vllm/model_executor/layers/quantization/quark/quark_moe.py
View file @
7cf56a59
...
...
@@ -444,7 +444,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
if
self
.
rocm_aiter_moe_enabled
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
...
...
@@ -634,7 +634,7 @@ class QuarkW4A8Fp8MoEMethod(QuarkMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
)
...
...
@@ -1027,7 +1027,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
if
not
self
.
emulate
:
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
rocm_aiter_fused_experts
,
...
...
vllm/model_executor/models/transformers/moe.py
View file @
7cf56a59
...
...
@@ -94,6 +94,8 @@ def transformers_moe_forward(
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
.
_topk_ids
=
topk_ids
# Clone hidden_states because it will be mutated in-place in FusedMoE
# TODO(bnell): figure out a way to avoid calling runner directly.
# it is a hack that the weights are being passed via the logits argument.
return
self
.
runner
.
forward
(
hidden_states
.
clone
(),
topk_weights
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment