Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7cf56a59
Unverified
Commit
7cf56a59
authored
Apr 01, 2026
by
bnellnm
Committed by
GitHub
Apr 01, 2026
Browse files
[MoE Refactor] Make SharedExperts class for use with DefaultMoERunner (#35153)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
5e30e9b9
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
515 additions
and
329 deletions
+515
-329
tests/kernels/moe/utils.py
tests/kernels/moe/utils.py
+0
-2
tools/ep_kernels/install_python_libraries.sh
tools/ep_kernels/install_python_libraries.sh
+1
-0
vllm/distributed/elastic_ep/elastic_execute.py
vllm/distributed/elastic_ep/elastic_execute.py
+1
-2
vllm/lora/layers/fused_moe.py
vllm/lora/layers/fused_moe.py
+0
-4
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+9
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+2
-0
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
...model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+3
-5
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+6
-4
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+23
-35
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+42
-35
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+5
-7
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+0
-1
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+5
-7
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+5
-7
vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py
...l_executor/layers/fused_moe/prepare_finalize/deepep_ll.py
+1
-1
vllm/model_executor/layers/fused_moe/router/fused_moe_router.py
...odel_executor/layers/fused_moe/router/fused_moe_router.py
+8
-0
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
...el_executor/layers/fused_moe/runner/default_moe_runner.py
+183
-219
vllm/model_executor/layers/fused_moe/runner/moe_runner.py
vllm/model_executor/layers/fused_moe/runner/moe_runner.py
+4
-0
vllm/model_executor/layers/fused_moe/runner/shared_experts.py
.../model_executor/layers/fused_moe/runner/shared_experts.py
+216
-0
No files found.
tests/kernels/moe/utils.py
View file @
7cf56a59
...
@@ -603,7 +603,6 @@ def make_shared_experts(
...
@@ -603,7 +603,6 @@ def make_shared_experts(
def
modular_triton_fused_moe
(
def
modular_triton_fused_moe
(
moe_config
:
FusedMoEConfig
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
quant_config
:
FusedMoEQuantConfig
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
)
->
FusedMoEKernel
:
)
->
FusedMoEKernel
:
return
FusedMoEKernel
(
return
FusedMoEKernel
(
maybe_make_prepare_finalize
(
maybe_make_prepare_finalize
(
...
@@ -613,6 +612,5 @@ def modular_triton_fused_moe(
...
@@ -613,6 +612,5 @@ def modular_triton_fused_moe(
use_monolithic
=
False
,
use_monolithic
=
False
,
),
),
TritonExperts
(
moe_config
,
quant_config
),
TritonExperts
(
moe_config
,
quant_config
),
shared_experts
,
inplace
=
False
,
inplace
=
False
,
)
)
tools/ep_kernels/install_python_libraries.sh
View file @
7cf56a59
...
@@ -103,6 +103,7 @@ pushd "$WORKSPACE"
...
@@ -103,6 +103,7 @@ pushd "$WORKSPACE"
echo
"Downloading NVSHMEM
${
NVSHMEM_VER
}
for
${
NVSHMEM_SUBDIR
}
..."
echo
"Downloading NVSHMEM
${
NVSHMEM_VER
}
for
${
NVSHMEM_SUBDIR
}
..."
curl
-fSL
"
${
NVSHMEM_URL
}
"
-o
"
${
NVSHMEM_FILE
}
"
curl
-fSL
"
${
NVSHMEM_URL
}
"
-o
"
${
NVSHMEM_FILE
}
"
tar
-xf
"
${
NVSHMEM_FILE
}
"
tar
-xf
"
${
NVSHMEM_FILE
}
"
rm
-rf
nvshmem
mv
"
${
NVSHMEM_FILE
%.tar.xz
}
"
nvshmem
mv
"
${
NVSHMEM_FILE
%.tar.xz
}
"
nvshmem
rm
-f
"
${
NVSHMEM_FILE
}
"
rm
-f
"
${
NVSHMEM_FILE
}
"
rm
-rf
nvshmem/lib/bin nvshmem/lib/share
rm
-rf
nvshmem/lib/bin nvshmem/lib/share
...
...
vllm/distributed/elastic_ep/elastic_execute.py
View file @
7cf56a59
...
@@ -410,8 +410,7 @@ class ElasticEPScalingExecutor:
...
@@ -410,8 +410,7 @@ class ElasticEPScalingExecutor:
# for the new EP size by resetting quant_method to base
# for the new EP size by resetting quant_method to base
for
module
in
moe_modules
:
for
module
in
moe_modules
:
if
hasattr
(
module
.
quant_method
,
"old_quant_method"
):
if
hasattr
(
module
.
quant_method
,
"old_quant_method"
):
module
.
quant_method
=
module
.
quant_method
.
old_quant_method
module
.
_replace_quant_method
(
module
.
quant_method
.
old_quant_method
)
module
.
runner
=
module
.
_init_runner
()
prepare_communication_buffer_for_model
(
self
.
worker
.
model_runner
.
model
)
prepare_communication_buffer_for_model
(
self
.
worker
.
model_runner
.
model
)
eplb_model_state
.
communicator
=
create_eplb_communicator
(
eplb_model_state
.
communicator
=
create_eplb_communicator
(
...
...
vllm/lora/layers/fused_moe.py
View file @
7cf56a59
...
@@ -595,10 +595,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
...
@@ -595,10 +595,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
def
maybe_all_reduce_tensor_model_parallel
(
self
,
*
args
,
**
kwargs
):
def
maybe_all_reduce_tensor_model_parallel
(
self
,
*
args
,
**
kwargs
):
return
self
.
base_layer
.
maybe_all_reduce_tensor_model_parallel
(
*
args
,
**
kwargs
)
return
self
.
base_layer
.
maybe_all_reduce_tensor_model_parallel
(
*
args
,
**
kwargs
)
@
property
def
_shared_experts
(
self
):
return
self
.
base_layer
.
_shared_experts
@
property
@
property
def
quant_method
(
self
):
def
quant_method
(
self
):
return
self
.
base_layer
.
quant_method
return
self
.
base_layer
.
quant_method
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
7cf56a59
...
@@ -937,6 +937,15 @@ class FusedMoEParallelConfig:
...
@@ -937,6 +937,15 @@ class FusedMoEParallelConfig:
all2all_backend
:
str
# all2all backend for MoE communication
all2all_backend
:
str
# all2all backend for MoE communication
enable_eplb
:
bool
# whether to enable expert load balancing
enable_eplb
:
bool
# whether to enable expert load balancing
@
property
def
use_dp_chunking
(
self
)
->
bool
:
return
(
self
.
use_deepep_ll_kernels
or
self
.
use_mori_kernels
or
self
.
use_fi_nvl_two_sided_kernels
or
self
.
use_nixl_ep_kernels
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
@
property
@
property
def
is_sequence_parallel
(
self
)
->
bool
:
def
is_sequence_parallel
(
self
)
->
bool
:
return
self
.
sp_size
>
1
return
self
.
sp_size
>
1
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
7cf56a59
...
@@ -1194,6 +1194,8 @@ def cutlass_moe_w4a8_fp8(
...
@@ -1194,6 +1194,8 @@ def cutlass_moe_w4a8_fp8(
quant_config
=
quant_config
,
quant_config
=
quant_config
,
group_size
=
group_size
,
group_size
=
group_size
,
),
),
shared_experts
=
None
,
inplace
=
False
,
)
)
return
fn
.
apply
(
return
fn
.
apply
(
...
...
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
View file @
7cf56a59
...
@@ -53,6 +53,7 @@ class TrtLlmFp8ExpertsBase:
...
@@ -53,6 +53,7 @@ class TrtLlmFp8ExpertsBase:
self
.
local_num_experts
=
moe_config
.
num_local_experts
self
.
local_num_experts
=
moe_config
.
num_local_experts
self
.
ep_rank
=
moe_config
.
moe_parallel_config
.
ep_rank
self
.
ep_rank
=
moe_config
.
moe_parallel_config
.
ep_rank
self
.
moe_config
=
moe_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
@
staticmethod
@
staticmethod
...
...
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
7cf56a59
...
@@ -40,9 +40,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -40,9 +40,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def
mk_owns_shared_expert
(
self
)
->
bool
:
def
mk_owns_shared_expert
(
self
)
->
bool
:
# NOTE(rob): temporary attribute to indicate support for
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
# completed migration to the new internal MK interface.
return
(
return
self
.
moe_kernel
is
not
None
and
self
.
moe_kernel
.
owns_shared_experts
self
.
moe_kernel
is
not
None
and
self
.
moe_kernel
.
shared_experts
is
not
None
)
@
abstractmethod
@
abstractmethod
def
create_weights
(
def
create_weights
(
...
@@ -163,7 +161,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -163,7 +161,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
apply_monolithic
(
def
apply_monolithic
(
...
@@ -171,5 +169,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -171,5 +169,5 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
layer
:
"FusedMoE"
,
# type: ignore[name-defined] # noqa: F821
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
7cf56a59
...
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEKernel
,
FusedMoEKernel
,
FusedMoEPrepareAndFinalizeModular
,
FusedMoEPrepareAndFinalizeModular
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -44,7 +47,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
...
@@ -44,7 +47,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
moe_layer
:
torch
.
nn
.
Module
,
moe_layer
:
torch
.
nn
.
Module
,
old_quant_method
:
FusedMoEMethodBase
,
old_quant_method
:
FusedMoEMethodBase
,
prepare_finalize
:
FusedMoEPrepareAndFinalizeModular
,
prepare_finalize
:
FusedMoEPrepareAndFinalizeModular
,
shared_experts
:
torch
.
nn
.
Module
|
None
,
shared_experts
:
SharedExperts
|
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
"FusedMoEModularMethod"
:
)
->
"FusedMoEModularMethod"
:
return
FusedMoEModularMethod
(
return
FusedMoEModularMethod
(
...
@@ -52,8 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
...
@@ -52,8 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
FusedMoEKernel
(
FusedMoEKernel
(
prepare_finalize
,
prepare_finalize
,
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
shared_experts
,
shared_experts
=
shared_experts
,
moe_parallel_config
=
moe_layer
.
moe_parallel_config
,
inplace
=
inplace
,
inplace
=
inplace
,
),
),
)
)
...
@@ -89,7 +91,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
...
@@ -89,7 +91,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
assert
self
.
moe_kernel
is
not
None
assert
self
.
moe_kernel
is
not
None
return
self
.
moe_kernel
.
apply
(
return
self
.
moe_kernel
.
apply
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
7cf56a59
...
@@ -42,6 +42,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
...
@@ -42,6 +42,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from
vllm.model_executor.layers.fused_moe.runner.default_moe_runner
import
(
from
vllm.model_executor.layers.fused_moe.runner.default_moe_runner
import
(
DefaultMoERunner
,
DefaultMoERunner
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
UnquantizedFusedMoEMethod
,
)
)
...
@@ -275,8 +278,6 @@ class FusedMoE(CustomOp):
...
@@ -275,8 +278,6 @@ class FusedMoE(CustomOp):
):
):
super
().
__init__
()
super
().
__init__
()
self
.
_gate
=
gate
self
.
_shared_experts
=
shared_experts
self
.
_routed_input_transform
=
routed_input_transform
self
.
_routed_input_transform
=
routed_input_transform
if
params_dtype
is
None
:
if
params_dtype
is
None
:
...
@@ -486,7 +487,7 @@ class FusedMoE(CustomOp):
...
@@ -486,7 +487,7 @@ class FusedMoE(CustomOp):
device
=
vllm_config
.
device_config
.
device
,
device
=
vllm_config
.
device_config
.
device
,
routing_method
=
self
.
routing_method_type
,
routing_method
=
self
.
routing_method_type
,
# TODO: in_dtype == out_dtype?
# TODO: in_dtype == out_dtype?
disable_inplace
=
disable_inplace
()
or
self
.
_
shared_experts
is
not
None
,
disable_inplace
=
disable_inplace
()
or
shared_experts
is
not
None
,
)
)
if
self
.
moe_config
.
use_mori_kernels
:
if
self
.
moe_config
.
use_mori_kernels
:
assert
self
.
rocm_aiter_fmoe_enabled
,
(
assert
self
.
rocm_aiter_fmoe_enabled
,
(
...
@@ -564,34 +565,20 @@ class FusedMoE(CustomOp):
...
@@ -564,34 +565,20 @@ class FusedMoE(CustomOp):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
base_quant_method
=
self
.
quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
# - we are using marlin kernels
backend
=
self
.
moe_parallel_config
.
all2all_backend
self
.
use_overlapped
=
(
not
(
(
self
.
enable_eplb
and
backend
!=
"allgather_reducescatter"
)
or
self
.
moe_parallel_config
.
use_fi_nvl_two_sided_kernels
)
and
self
.
_shared_experts
is
not
None
)
self
.
runner
=
self
.
_init_runner
()
# TODO(bnell): this is un-needed and removed in a follow up PR.
self
.
base_quant_method
=
self
.
quant_method
def
_init_runner
(
self
):
# Storing the runner in the FusedMoE is an intermediate state, eventually
# Storing the runner in the FusedMoE is an intermediate state, eventually
# the runner will own the FusedMoE layer and provide the execution interface
# the runner will own the FusedMoE layer and provide the execution interface
# for MoE ops.
# for MoE ops.
return
DefaultMoERunner
(
self
.
runner
=
DefaultMoERunner
(
layer
=
self
,
layer
=
self
,
moe_config
=
self
.
moe_config
,
moe_config
=
self
.
moe_config
,
router
=
self
.
router
,
router
=
self
.
router
,
routed_input_transform
=
self
.
_routed_input_transform
,
routed_input_transform
=
self
.
_routed_input_transform
,
gate
=
self
.
gate
,
gate
=
gate
,
shared_experts
=
self
.
shared_experts
,
shared_experts
=
shared_experts
,
quant_method
=
self
.
quant_method
,
quant_method
=
self
.
quant_method
,
reduce_results
=
self
.
reduce_results
,
reduce_results
=
self
.
reduce_results
,
enable_dbo
=
self
.
vllm_config
.
parallel_config
.
enable_dbo
,
enable_dbo
=
self
.
vllm_config
.
parallel_config
.
enable_dbo
,
...
@@ -602,10 +589,7 @@ class FusedMoE(CustomOp):
...
@@ -602,10 +589,7 @@ class FusedMoE(CustomOp):
# intrusive way to do this.
# intrusive way to do this.
def
_replace_quant_method
(
self
,
mk
:
FusedMoEMethodBase
):
def
_replace_quant_method
(
self
,
mk
:
FusedMoEMethodBase
):
self
.
quant_method
=
mk
self
.
quant_method
=
mk
# We need to force reconstruction of runner because we're swapping out
self
.
runner
.
_replace_quant_method
(
mk
)
# the quant_method with a FusedMoEModularMethod. This logic can go
# away once the FusedMoEModularMethod is eliminated.
self
.
runner
=
self
.
_init_runner
()
# Note: maybe_init_modular_kernel should only be called by
# Note: maybe_init_modular_kernel should only be called by
# prepare_communication_buffer_for_model.
# prepare_communication_buffer_for_model.
...
@@ -639,8 +623,8 @@ class FusedMoE(CustomOp):
...
@@ -639,8 +623,8 @@ class FusedMoE(CustomOp):
)
)
@
property
@
property
def
shared_experts
(
self
)
->
torch
.
nn
.
Module
|
None
:
def
shared_experts
(
self
)
->
SharedExperts
|
None
:
return
self
.
_
shared_experts
if
self
.
use_overlapped
else
None
return
self
.
runner
.
shared_experts
@
property
@
property
def
layer_id
(
self
):
def
layer_id
(
self
):
...
@@ -649,10 +633,6 @@ class FusedMoE(CustomOp):
...
@@ -649,10 +633,6 @@ class FusedMoE(CustomOp):
return
extract_layer_index
(
self
.
layer_name
)
return
extract_layer_index
(
self
.
layer_name
)
@
property
def
gate
(
self
)
->
torch
.
nn
.
Module
|
None
:
return
self
.
_gate
if
self
.
use_overlapped
else
None
@
property
@
property
def
tp_size
(
self
):
def
tp_size
(
self
):
return
self
.
moe_parallel_config
.
tp_size
return
self
.
moe_parallel_config
.
tp_size
...
@@ -676,7 +656,7 @@ class FusedMoE(CustomOp):
...
@@ -676,7 +656,7 @@ class FusedMoE(CustomOp):
@
property
@
property
def
is_internal_router
(
self
)
->
bool
:
def
is_internal_router
(
self
)
->
bool
:
# By default, router/gate is called before FusedMoE forward pass
# By default, router/gate is called before FusedMoE forward pass
return
self
.
gate
is
not
None
return
self
.
runner
.
is_internal_router
()
def
_maybe_init_expert_routing_tables
(
def
_maybe_init_expert_routing_tables
(
self
,
self
,
...
@@ -1467,7 +1447,12 @@ class FusedMoE(CustomOp):
...
@@ -1467,7 +1447,12 @@ class FusedMoE(CustomOp):
assert
all
(
assert
all
(
weight
.
is_contiguous
()
weight
.
is_contiguous
()
for
name
,
weight
in
weights
for
name
,
weight
in
weights
if
not
(
name
.
startswith
(
"_shared_experts."
)
or
name
.
startswith
(
"_gate."
))
if
not
(
name
.
startswith
(
"_shared_experts."
)
or
name
.
startswith
(
"_gate."
)
or
name
.
startswith
(
"_routed_input_transform."
)
or
name
.
startswith
(
"_routed_output_transform."
)
)
and
name
not
in
NON_EXPERT_WEIGHTS
and
name
not
in
NON_EXPERT_WEIGHTS
)
)
...
@@ -1477,8 +1462,11 @@ class FusedMoE(CustomOp):
...
@@ -1477,8 +1462,11 @@ class FusedMoE(CustomOp):
if
name
not
in
NON_EXPERT_WEIGHTS
if
name
not
in
NON_EXPERT_WEIGHTS
and
weight
.
shape
!=
torch
.
Size
([])
and
weight
.
shape
!=
torch
.
Size
([])
and
not
name
.
startswith
(
"_shared_experts."
)
and
not
name
.
startswith
(
"_shared_experts."
)
# exclude parameters from non-expert submodules (e.g. gate/shared)
# exclude parameters from non-expert submodules,
# e.g. gate/shared/transforms.
and
not
name
.
startswith
(
"_gate."
)
and
not
name
.
startswith
(
"_gate."
)
and
not
name
.
startswith
(
"_routed_input_transform."
)
and
not
name
.
startswith
(
"_routed_output_transform."
)
]
]
def
set_eplb_state
(
def
set_eplb_state
(
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
7cf56a59
...
@@ -21,6 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -21,6 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
RoutingMethodType
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
SharedExpertsOrder
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
_resize_cache
,
disable_inplace
,
disable_inplace
,
...
@@ -235,6 +239,13 @@ class FusedMoEPrepareAndFinalize(ABC):
...
@@ -235,6 +239,13 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
supports_async
(
self
)
->
bool
:
"""
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return
False
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class
FusedMoEPrepareAndFinalizeModular
(
FusedMoEPrepareAndFinalize
):
class
FusedMoEPrepareAndFinalizeModular
(
FusedMoEPrepareAndFinalize
):
...
@@ -281,13 +292,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
...
@@ -281,13 +292,6 @@ class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
supports_async
(
self
)
->
bool
:
"""
Indicates whether or not this class implements prepare_async and
finalize_async.
"""
return
False
def
prepare_async
(
def
prepare_async
(
self
,
self
,
a1
:
torch
.
Tensor
,
a1
:
torch
.
Tensor
,
...
@@ -1003,15 +1007,20 @@ class FusedMoEKernelModularImpl:
...
@@ -1003,15 +1007,20 @@ class FusedMoEKernelModularImpl:
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalizeModular
,
prepare_finalize
:
FusedMoEPrepareAndFinalizeModular
,
fused_experts
:
FusedMoEExpertsModular
,
fused_experts
:
FusedMoEExpertsModular
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
SharedExperts
|
None
,
moe_parallel_config
:
FusedMoEParallelConfig
|
None
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
):
):
self
.
prepare_finalize
=
prepare_finalize
self
.
prepare_finalize
=
prepare_finalize
self
.
fused_experts
=
fused_experts
self
.
fused_experts
=
fused_experts
self
.
shared_experts
=
shared_experts
# Only accept shared experts if they can be run w/async.
self
.
moe_parallel_config
=
moe_parallel_config
# The MoERunner/SharedExperts class will coordinate with the MK to ensure
# that the SharedExperts are executed only once.
self
.
shared_experts
=
(
shared_experts
if
prepare_finalize
.
supports_async
()
else
None
)
self
.
inplace
=
inplace
self
.
inplace
=
inplace
moe_parallel_config
=
fused_experts
.
moe_config
.
moe_parallel_config
self
.
moe_parallel_config
=
moe_parallel_config
self
.
is_dp_ep
=
(
self
.
is_dp_ep
=
(
moe_parallel_config
is
not
None
moe_parallel_config
is
not
None
and
moe_parallel_config
.
dp_size
>
1
and
moe_parallel_config
.
dp_size
>
1
...
@@ -1081,6 +1090,17 @@ class FusedMoEKernelModularImpl:
...
@@ -1081,6 +1090,17 @@ class FusedMoEKernelModularImpl:
return
workspace13
,
workspace2
,
fused_out
return
workspace13
,
workspace2
,
fused_out
def
_maybe_apply_shared_experts
(
self
,
shared_experts_input
:
torch
.
Tensor
|
None
,
):
if
self
.
shared_experts
is
not
None
:
assert
shared_experts_input
is
not
None
self
.
shared_experts
.
apply
(
shared_experts_input
,
SharedExpertsOrder
.
MK_INTERNAL_OVERLAPPED
,
)
def
_prepare
(
def
_prepare
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -1253,15 +1273,6 @@ class FusedMoEKernelModularImpl:
...
@@ -1253,15 +1273,6 @@ class FusedMoEKernelModularImpl:
shared_experts_input is the original hidden_states (full
shared_experts_input is the original hidden_states (full
dimension) needed by the shared expert MLP.
dimension) needed by the shared expert MLP.
"""
"""
shared_output
:
torch
.
Tensor
|
None
=
None
# For latent MoE: shared experts need the original hidden_states
# (full hidden_size), not the latent-projected version used by
# routed experts.
se_hidden_states
=
(
shared_experts_input
if
shared_experts_input
is
not
None
else
hidden_states
)
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
assert
not
dbo_enabled
()
assert
not
dbo_enabled
()
...
@@ -1273,8 +1284,6 @@ class FusedMoEKernelModularImpl:
...
@@ -1273,8 +1284,6 @@ class FusedMoEKernelModularImpl:
apply_router_weight_on_input
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
se_hidden_states
)
else
:
else
:
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
output
,
...
@@ -1284,8 +1293,7 @@ class FusedMoEKernelModularImpl:
...
@@ -1284,8 +1293,7 @@ class FusedMoEKernelModularImpl:
apply_router_weight_on_input
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
)
if
self
.
shared_experts
is
not
None
:
self
.
_maybe_apply_shared_experts
(
shared_experts_input
)
shared_output
=
self
.
shared_experts
(
se_hidden_states
)
# TODO(lucas): refactor this in the alternative schedules followup
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# currently unpack if we have hook + receiver pair or just
...
@@ -1308,11 +1316,7 @@ class FusedMoEKernelModularImpl:
...
@@ -1308,11 +1316,7 @@ class FusedMoEKernelModularImpl:
receiver
()
receiver
()
if
self
.
shared_experts
is
None
:
return
output
return
output
else
:
assert
shared_output
is
not
None
return
shared_output
,
output
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1326,7 +1330,7 @@ class FusedMoEKernelModularImpl:
...
@@ -1326,7 +1330,7 @@ class FusedMoEKernelModularImpl:
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
shared_experts_input
:
torch
.
Tensor
|
None
=
None
,
shared_experts_input
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
of weights, w1 and w2, and top-k gating mechanism.
...
@@ -1469,12 +1473,10 @@ class FusedMoEKernel:
...
@@ -1469,12 +1473,10 @@ class FusedMoEKernel:
self
,
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
fused_experts
:
FusedMoEExperts
,
fused_experts
:
FusedMoEExperts
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
SharedExperts
|
None
=
None
,
moe_parallel_config
:
FusedMoEParallelConfig
|
None
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
shared_experts
=
shared_experts
# NOTE: check if we can remove
# Initialize the implementation (monolithic or modular).
# Initialize the implementation (monolithic or modular).
self
.
impl
:
FusedMoEKernelModularImpl
|
FusedMoEKernelMonolithicImpl
self
.
impl
:
FusedMoEKernelModularImpl
|
FusedMoEKernelMonolithicImpl
...
@@ -1485,14 +1487,12 @@ class FusedMoEKernel:
...
@@ -1485,14 +1487,12 @@ class FusedMoEKernel:
prepare_finalize
,
prepare_finalize
,
fused_experts
,
fused_experts
,
shared_experts
,
shared_experts
,
moe_parallel_config
,
inplace
,
inplace
,
)
)
elif
isinstance
(
elif
isinstance
(
prepare_finalize
,
FusedMoEPrepareAndFinalizeMonolithic
prepare_finalize
,
FusedMoEPrepareAndFinalizeMonolithic
)
and
isinstance
(
fused_experts
,
FusedMoEExpertsMonolithic
):
)
and
isinstance
(
fused_experts
,
FusedMoEExpertsMonolithic
):
assert
shared_experts
is
None
assert
not
inplace
assert
not
inplace
self
.
impl
=
FusedMoEKernelMonolithicImpl
(
self
.
impl
=
FusedMoEKernelMonolithicImpl
(
prepare_finalize
,
prepare_finalize
,
...
@@ -1508,6 +1508,13 @@ class FusedMoEKernel:
...
@@ -1508,6 +1508,13 @@ class FusedMoEKernel:
self
.
_post_init_setup
()
self
.
_post_init_setup
()
@
property
def
owns_shared_experts
(
self
)
->
bool
:
if
isinstance
(
self
.
impl
,
FusedMoEKernelModularImpl
):
return
self
.
impl
.
shared_experts
is
not
None
else
:
return
False
@
property
@
property
def
is_monolithic
(
self
)
->
bool
:
def
is_monolithic
(
self
)
->
bool
:
return
isinstance
(
self
.
impl
,
FusedMoEKernelMonolithicImpl
)
return
isinstance
(
self
.
impl
,
FusedMoEKernelMonolithicImpl
)
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
7cf56a59
...
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
FlashinferMoeBackend
,
get_flashinfer_moe_backend
,
get_flashinfer_moe_backend
,
...
@@ -545,7 +548,7 @@ def make_fp8_moe_kernel(
...
@@ -545,7 +548,7 @@ def make_fp8_moe_kernel(
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
fp8_backend
:
Fp8MoeBackend
,
fp8_backend
:
Fp8MoeBackend
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
SharedExperts
|
None
=
None
,
)
->
mk
.
FusedMoEKernel
:
)
->
mk
.
FusedMoEKernel
:
# Create Prepare/Finalize.
# Create Prepare/Finalize.
prepare_finalize
=
maybe_make_prepare_finalize
(
prepare_finalize
=
maybe_make_prepare_finalize
(
...
@@ -581,12 +584,7 @@ def make_fp8_moe_kernel(
...
@@ -581,12 +584,7 @@ def make_fp8_moe_kernel(
kernel
=
mk
.
FusedMoEKernel
(
kernel
=
mk
.
FusedMoEKernel
(
prepare_finalize
,
prepare_finalize
,
experts
,
experts
,
shared_experts
=
(
shared_experts
=
shared_experts
,
shared_experts
if
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
else
None
),
moe_parallel_config
=
moe_config
.
moe_parallel_config
,
inplace
=
(
inplace
=
(
not
moe_config
.
disable_inplace
not
moe_config
.
disable_inplace
and
fp8_backend
!=
Fp8MoeBackend
.
FLASHINFER_CUTLASS
and
fp8_backend
!=
Fp8MoeBackend
.
FLASHINFER_CUTLASS
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
View file @
7cf56a59
...
@@ -859,7 +859,6 @@ def make_mxfp4_moe_kernel(
...
@@ -859,7 +859,6 @@ def make_mxfp4_moe_kernel(
if
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
if
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
else
None
else
None
),
),
moe_parallel_config
=
moe_config
.
moe_parallel_config
,
inplace
=
(
inplace
=
(
not
moe_config
.
disable_inplace
and
mxfp4_backend
not
in
TRTLLM_BACKENDS
not
moe_config
.
disable_inplace
and
mxfp4_backend
not
in
TRTLLM_BACKENDS
),
),
...
...
vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
View file @
7cf56a59
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config
,
nvfp4_moe_quant_config
,
nvfp4_w4a16_moe_quant_config
,
nvfp4_w4a16_moe_quant_config
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
prepare_nvfp4_moe_layer_for_fi_or_cutlass
,
prepare_nvfp4_moe_layer_for_fi_or_cutlass
,
)
)
...
@@ -386,7 +389,7 @@ def make_nvfp4_moe_kernel(
...
@@ -386,7 +389,7 @@ def make_nvfp4_moe_kernel(
moe_config
:
FusedMoEConfig
,
moe_config
:
FusedMoEConfig
,
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
SharedExperts
|
None
=
None
,
)
->
mk
.
FusedMoEKernel
:
)
->
mk
.
FusedMoEKernel
:
# Create Prepare/Finalize.
# Create Prepare/Finalize.
prepare_finalize
=
maybe_make_prepare_finalize
(
prepare_finalize
=
maybe_make_prepare_finalize
(
...
@@ -422,12 +425,7 @@ def make_nvfp4_moe_kernel(
...
@@ -422,12 +425,7 @@ def make_nvfp4_moe_kernel(
kernel
=
mk
.
FusedMoEKernel
(
kernel
=
mk
.
FusedMoEKernel
(
prepare_finalize
,
prepare_finalize
,
experts
,
experts
,
shared_experts
=
(
shared_experts
=
shared_experts
,
shared_experts
if
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
else
None
),
moe_parallel_config
=
moe_config
.
moe_parallel_config
,
inplace
=
False
,
inplace
=
False
,
)
)
...
...
vllm/model_executor/layers/fused_moe/oracle/unquantized.py
View file @
7cf56a59
...
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
FlashinferMoeBackend
,
convert_moe_weights_to_flashinfer_trtllm_block_layout
,
convert_moe_weights_to_flashinfer_trtllm_block_layout
,
...
@@ -321,7 +324,7 @@ def make_unquantized_moe_kernel(
...
@@ -321,7 +324,7 @@ def make_unquantized_moe_kernel(
backend
:
UnquantizedMoeBackend
,
backend
:
UnquantizedMoeBackend
,
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
experts_cls
:
type
[
mk
.
FusedMoEExperts
],
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
shared_experts
:
SharedExperts
|
None
=
None
,
)
->
mk
.
FusedMoEKernel
:
)
->
mk
.
FusedMoEKernel
:
# Create Prepare/Finalize
# Create Prepare/Finalize
is_monolithic
=
issubclass
(
experts_cls
,
mk
.
FusedMoEExpertsMonolithic
)
is_monolithic
=
issubclass
(
experts_cls
,
mk
.
FusedMoEExpertsMonolithic
)
...
@@ -355,12 +358,7 @@ def make_unquantized_moe_kernel(
...
@@ -355,12 +358,7 @@ def make_unquantized_moe_kernel(
kernel
=
mk
.
FusedMoEKernel
(
kernel
=
mk
.
FusedMoEKernel
(
prepare_finalize
,
prepare_finalize
,
experts
,
experts
,
shared_experts
=
(
shared_experts
=
shared_experts
,
shared_experts
if
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
else
None
),
moe_parallel_config
=
moe_config
.
moe_parallel_config
,
inplace
=
(
not
moe_config
.
disable_inplace
and
not
is_monolithic
),
inplace
=
(
not
moe_config
.
disable_inplace
and
not
is_monolithic
),
)
)
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize/deepep_ll.py
View file @
7cf56a59
...
@@ -325,7 +325,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
...
@@ -325,7 +325,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
**
(
dict
(
use_nvfp4
=
True
)
if
use_nvfp4
else
dict
()),
**
(
dict
(
use_nvfp4
=
True
)
if
use_nvfp4
else
dict
()),
**
(
**
(
dict
(
x_global_scale
=
qc_a1_gscale_or_scale
)
dict
(
x_global_scale
=
qc_a1_gscale_or_scale
)
if
qc_a1_gscale_or_scale
is
not
None
if
qc_a1_gscale_or_scale
is
not
None
and
nvfp4_dispatch
else
dict
()
else
dict
()
),
),
async_finish
=
False
,
async_finish
=
False
,
...
...
vllm/model_executor/layers/fused_moe/router/fused_moe_router.py
View file @
7cf56a59
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
import
torch
import
torch
...
@@ -13,6 +14,13 @@ class FusedMoERouter(ABC):
...
@@ -13,6 +14,13 @@ class FusedMoERouter(ABC):
method that is used for routing hidden states based on router logits.
method that is used for routing hidden states based on router logits.
"""
"""
@
abstractmethod
def
set_capture_fn
(
self
,
capture_fn
:
Callable
[[
torch
.
Tensor
],
None
]
|
None
,
)
->
None
:
raise
NotImplementedError
@
property
@
property
@
abstractmethod
@
abstractmethod
def
routing_method_type
(
self
)
->
RoutingMethodType
:
def
routing_method_type
(
self
)
->
RoutingMethodType
:
...
...
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
View file @
7cf56a59
...
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING
...
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.distributed
import
(
from
vllm.distributed
import
(
get_ep_group
,
get_ep_group
,
get_pcp_group
,
get_pcp_group
,
...
@@ -29,13 +28,15 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
...
@@ -29,13 +28,15 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter
,
FusedMoERouter
,
)
)
from
vllm.model_executor.layers.fused_moe.runner.moe_runner
import
MoERunner
from
vllm.model_executor.layers.fused_moe.runner.moe_runner
import
MoERunner
from
vllm.model_executor.layers.fused_moe.runner.shared_experts
import
(
SharedExperts
,
SharedExpertsOrder
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
(
from
vllm.utils.torch_utils
import
(
HAS_OPAQUE_TYPE
,
HAS_OPAQUE_TYPE
,
ModuleName
,
ModuleName
,
aux_stream
,
current_stream
,
direct_register_custom_op
,
direct_register_custom_op
,
)
)
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
...
@@ -74,6 +75,9 @@ def _resolve_layer_name(layer_name: str | ModuleName) -> str:
...
@@ -74,6 +75,9 @@ def _resolve_layer_name(layer_name: str | ModuleName) -> str:
return
layer_name
.
value
if
isinstance
(
layer_name
,
ModuleName
)
else
layer_name
return
layer_name
.
value
if
isinstance
(
layer_name
,
ModuleName
)
else
layer_name
# Note: _moe_forward and _moe_forward_shared should not contain any
# implementation details, They should merely pass along control to
# the runner's 'forward_dispatch' method.
def
_moe_forward
(
def
_moe_forward
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
@@ -81,19 +85,7 @@ def _moe_forward(
...
@@ -81,19 +85,7 @@ def _moe_forward(
layer_name
:
_layer_name_type
,
layer_name
:
_layer_name_type
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
return
layer
.
runner
.
forward_dispatch
(
layer
.
ensure_moe_quant_config_init
()
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
layer
,
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
...
@@ -117,19 +109,7 @@ def _moe_forward_shared(
...
@@ -117,19 +109,7 @@ def _moe_forward_shared(
layer_name
:
_layer_name_type
,
layer_name
:
_layer_name_type
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
return
layer
.
runner
.
forward_dispatch
(
layer
.
ensure_moe_quant_config_init
()
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
layer
,
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
...
@@ -159,7 +139,7 @@ def _moe_forward_shared_fake(
...
@@ -159,7 +139,7 @@ def _moe_forward_shared_fake(
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"moe_forward"
,
op_name
=
"moe_forward"
,
op_func
=
_moe_forward
,
op_func
=
_moe_forward
,
mutates_args
=
[
"hidden_states"
],
mutates_args
=
[
"hidden_states"
],
# is this still true?
fake_impl
=
_moe_forward_fake
,
fake_impl
=
_moe_forward_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
)
...
@@ -168,7 +148,6 @@ direct_register_custom_op(
...
@@ -168,7 +148,6 @@ direct_register_custom_op(
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"moe_forward_shared"
,
op_name
=
"moe_forward_shared"
,
op_func
=
_moe_forward_shared
,
op_func
=
_moe_forward_shared
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
_moe_forward_shared_fake
,
fake_impl
=
_moe_forward_shared_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
)
...
@@ -213,87 +192,68 @@ class DefaultMoERunner(MoERunner):
...
@@ -213,87 +192,68 @@ class DefaultMoERunner(MoERunner):
self
.
router
=
router
self
.
router
=
router
self
.
routed_input_transform
=
routed_input_transform
self
.
routed_input_transform
=
routed_input_transform
self
.
gate
=
gate
self
.
gate
=
gate
self
.
shared_experts
=
shared_experts
self
.
quant_method
=
quant_method
self
.
quant_method
=
quant_method
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
self
.
enable_dbo
=
enable_dbo
self
.
enable_dbo
=
enable_dbo
self
.
shared_experts
:
SharedExperts
|
None
=
None
if
shared_experts
is
not
None
:
self
.
shared_experts
=
SharedExperts
(
shared_experts
,
moe_config
=
moe_config
,
# Note: For now we must pass quant_method along to SharedExperts so it
# can property determine where the shared experts are supposed to be
# called, i.e. by a MK or by the MoERunner.
# Once the MK can be created upfront, we can just pass in the proper
# flags derived from the quant_method's MK.
reduce_results
=
reduce_results
,
quant_method
=
quant_method
,
enable_dbo
=
enable_dbo
,
)
# Chunked all2all staging tensor
# Chunked all2all staging tensor
# TODO(bnell) rename these?
# These need to exist ahead of time due to CUDAgraph construction
# needing a fixed buffer address.
self
.
use_dp_chunking
=
self
.
moe_config
.
moe_parallel_config
.
use_dp_chunking
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
_maybe_init_dp_chunking
()
self
.
_maybe_init_dp_chunking
()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
self
.
use_shared_experts_stream
=
False
if
envs
.
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
logger
.
debug_once
(
"Disabling MoE shared_experts cuda stream"
,
scope
=
"local"
)
self
.
shared_experts_stream
=
None
else
:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self
.
shared_experts_stream
=
aux_stream
()
if
self
.
shared_experts_stream
is
not
None
:
logger
.
debug_once
(
"Enabled separate cuda stream for MoE shared_experts"
,
scope
=
"local"
)
# Needed for string -> FusedMoE layer lookup in custom ops.
# Needed for string -> FusedMoE layer lookup in custom ops.
self
.
layer_name
=
layer
.
layer_name
self
.
layer_name
=
layer
.
layer_name
self
.
moe_forward
=
self
.
_select_forward
(
layer
)
self
.
forward_entry
,
self
.
forward_impl
=
self
.
_select_forward
(
layer
)
def
_select_forward
(
self
,
layer
:
torch
.
nn
.
Module
)
->
tuple
[
Callable
,
Callable
]:
# Select implementation based on presence of DP chunking.
forward_impl_fn
=
(
self
.
_forward_impl_chunked
if
self
.
use_dp_chunking
else
self
.
_forward_impl
)
def
_select_forward
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Callable
:
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
# Note: CPU doesn't require wrapped forward_impl.
return
_moe_forward
if
self
.
shared_experts
is
None
else
_moe_forward_shared
return
(
return
(
torch
.
ops
.
vllm
.
moe_forward
_moe_forward
if
self
.
shared_experts
is
None
else
_moe_forward_shared
,
if
self
.
shared_experts
is
None
forward_impl_fn
,
else
torch
.
ops
.
vllm
.
moe_forward_shared
)
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
return
(
return
(
self
.
moe_config
.
moe_parallel_config
.
use_deepep_ll_kernels
torch
.
ops
.
vllm
.
moe_forward
or
self
.
moe_config
.
moe_parallel_config
.
use_mori_kernels
if
self
.
shared_experts
is
None
or
self
.
moe_config
.
moe_parallel_config
.
use_fi_nvl_two_sided_kernels
else
torch
.
ops
.
vllm
.
moe_forward_shared
,
or
self
.
moe_config
.
moe_parallel_config
.
use_nixl_ep_kernels
forward_impl_fn
,
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
def
_maybe_setup_shared_experts_stream
(
self
,
hidden_states
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
):
if
self
.
use_shared_experts_stream
:
assert
self
.
shared_experts_stream
is
not
None
assert
self
.
moe_config
.
disable_inplace
shared_experts_input
=
(
shared_input
if
shared_input
is
not
None
else
hidden_states
)
)
# Record that the shared_experts_input will be used in the
# TODO(bnell): temporary hack, do not call this method.
# shared_experts_stream to avoid gc issue from
def
_replace_quant_method
(
self
,
quant_method
:
FusedMoEMethodBase
):
# deallocation. For more details:
if
self
.
shared_experts
is
not
None
:
# https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
self
.
shared_experts
.
_quant_method
=
quant_method
# NOTE: We don't need shared_output.record_stream(current_stream())
self
.
quant_method
=
quant_method
# because we synch the streams before using shared_output.
shared_experts_input
.
record_stream
(
self
.
shared_experts_stream
)
# Mark sync start point for the separate shared experts
def
is_internal_router
(
self
)
->
bool
:
# stream here since we want to run in parallel with the
return
self
.
gate
is
not
None
# router/gate (next op below)
assert
self
.
shared_experts_stream
is
not
None
self
.
shared_experts_stream
.
wait_stream
(
current_stream
())
def
_maybe_init_dp_chunking
(
self
):
def
_maybe_init_dp_chunking
(
self
):
if
not
self
.
use_dp_chunking
:
if
not
self
.
use_dp_chunking
:
...
@@ -325,38 +285,6 @@ class DefaultMoERunner(MoERunner):
...
@@ -325,38 +285,6 @@ class DefaultMoERunner(MoERunner):
device
=
device
,
device
=
device
,
)
)
@
property
def
has_separate_shared_experts
(
self
)
->
bool
:
return
(
not
self
.
quant_method
.
mk_owns_shared_expert
and
self
.
shared_experts
is
not
None
)
def
_apply_shared_experts
(
self
,
hidden_states
:
torch
.
Tensor
,
allow_streaming
:
bool
=
False
,
)
->
torch
.
Tensor
|
None
:
shared_output
:
torch
.
Tensor
|
None
=
None
if
self
.
has_separate_shared_experts
:
assert
self
.
shared_experts
is
not
None
if
self
.
use_shared_experts_stream
and
allow_streaming
:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with
torch
.
cuda
.
stream
(
self
.
shared_experts_stream
):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output
=
self
.
shared_experts
(
hidden_states
)
current_stream
().
wait_stream
(
self
.
shared_experts_stream
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
return
shared_output
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
"""
"""
The shared_experts are typically computed using the RowParallelLinear
The shared_experts are typically computed using the RowParallelLinear
...
@@ -384,7 +312,9 @@ class DefaultMoERunner(MoERunner):
...
@@ -384,7 +312,9 @@ class DefaultMoERunner(MoERunner):
else
:
else
:
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
tensor_model_parallel_all_reduce
(
final_hidden_states
)
def
apply_routed_input_transform
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply_routed_input_transform
(
self
,
hidden_states
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""Apply transform for routed experts (e.g., latent projection).
"""Apply transform for routed experts (e.g., latent projection).
This is called by FusedMoE.forward_native. The original hidden_states
This is called by FusedMoE.forward_native. The original hidden_states
...
@@ -394,15 +324,22 @@ class DefaultMoERunner(MoERunner):
...
@@ -394,15 +324,22 @@ class DefaultMoERunner(MoERunner):
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
moved inside SharedFusedMoE to all-reduce on the smaller latent
moved inside SharedFusedMoE to all-reduce on the smaller latent
dimension.
dimension.
Returns (possibly transformed) hidden states and the input for shared
experts (or None if there are no shared experts).
"""
"""
if
self
.
routed_input_transform
is
not
None
:
if
self
.
routed_input_transform
is
not
None
:
result
=
self
.
routed_input_transform
(
hidden_states
)
result
=
self
.
routed_input_transform
(
hidden_states
)
# ReplicatedLinear returns (output, extra_bias) tuple.
# ReplicatedLinear returns (output, extra_bias) tuple.
# We only need the output tensor; extra_bias is not used here.
# We only need the output tensor; extra_bias is not used here.
if
isinstance
(
result
,
tuple
):
if
isinstance
(
result
,
tuple
):
return
result
[
0
]
return
result
[
0
],
hidden_states
return
result
return
result
,
hidden_states
return
hidden_states
return
(
hidden_states
,
hidden_states
if
self
.
shared_experts
is
not
None
else
None
,
)
def
_maybe_reduce_output
(
def
_maybe_reduce_output
(
self
,
self
,
...
@@ -446,13 +383,11 @@ class DefaultMoERunner(MoERunner):
...
@@ -446,13 +383,11 @@ class DefaultMoERunner(MoERunner):
def
_maybe_pad_hidden_states
(
def
_maybe_pad_hidden_states
(
self
,
self
,
original_hidden_states
:
torch
.
Tensor
|
None
,
shared_experts_input
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
list
[
int
]]:
)
->
tuple
[
torch
.
Tensor
,
list
[
int
]]:
original_hidden_dim
=
(
shared_experts_hidden_dim
=
(
original_hidden_states
.
shape
[
-
1
]
shared_experts_input
.
shape
[
-
1
]
if
shared_experts_input
is
not
None
else
0
if
original_hidden_states
is
not
None
else
0
)
)
transformed_hidden_dim
=
hidden_states
.
shape
[
-
1
]
transformed_hidden_dim
=
hidden_states
.
shape
[
-
1
]
if
(
if
(
...
@@ -467,29 +402,37 @@ class DefaultMoERunner(MoERunner):
...
@@ -467,29 +402,37 @@ class DefaultMoERunner(MoERunner):
)
)
if
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
orig_hidden_dims
=
[
original
_hidden_dim
,
transformed_hidden_dim
]
orig_hidden_dims
=
[
shared_experts
_hidden_dim
,
transformed_hidden_dim
]
else
:
else
:
orig_hidden_dims
=
[
transformed_hidden_dim
]
orig_hidden_dims
=
[
transformed_hidden_dim
]
return
hidden_states
,
orig_hidden_dims
return
hidden_states
,
orig_hidden_dims
def
_maybe_apply_shared_experts
(
self
,
shared_experts_input
:
torch
.
Tensor
|
None
,
order
:
SharedExpertsOrder
,
):
if
self
.
shared_experts
is
not
None
:
assert
shared_experts_input
is
not
None
self
.
shared_experts
.
apply
(
shared_experts_input
,
order
)
def
_apply_quant_method
(
def
_apply_quant_method
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
shared_experts_input
:
torch
.
Tensor
|
None
,
run_shared_experts_before
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
shared_input
=
shared_input
if
shared_input
is
not
None
else
hidden_states
shared_output
:
torch
.
Tensor
|
None
=
None
# Run this before quant_method to avoid inplace issues.
# Run this before quant_method to avoid inplace issues.
if
run_shared_experts_before
:
# TODO(bnell): probably not needed anymore since inplace is
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
False
)
# disabled when shared experts are present.
self
.
_maybe_apply_shared_experts
(
shared_experts_input
,
SharedExpertsOrder
.
NO_OVERLAP
)
if
self
.
quant_method
.
is_monolithic
:
if
self
.
quant_method
.
is_monolithic
:
resul
t
=
self
.
quant_method
.
apply_monolithic
(
fused_ou
t
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
layer
=
layer
,
x
=
hidden_states
,
x
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
@@ -500,25 +443,25 @@ class DefaultMoERunner(MoERunner):
...
@@ -500,25 +443,25 @@ class DefaultMoERunner(MoERunner):
router_logits
=
router_logits
,
router_logits
=
router_logits
,
)
)
result
=
self
.
quant_method
.
apply
(
# Passing shared_experts_input in case SharedExpertsOrder is
# NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
fused_out
=
self
.
quant_method
.
apply
(
layer
=
layer
,
layer
=
layer
,
x
=
hidden_states
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
shared_experts_input
=
shared_
experts_
input
,
)
)
if
isinstance
(
result
,
tuple
):
self
.
_maybe_apply_shared_experts
(
assert
shared_output
is
None
shared_experts_input
,
shared_output
,
hidden_states
=
result
SharedExpertsOrder
.
MULTI_STREAM_OVERLAPPED
,
else
:
)
hidden_states
=
result
if
not
run_shared_experts_before
and
self
.
has_separate_shared_experts
:
assert
shared_output
is
None
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
True
)
return
shared_output
,
hidden_states
return
(
self
.
shared_experts
.
output
if
self
.
shared_experts
is
not
None
else
None
,
fused_out
,
)
def
_sequence_parallel_context
(
self
):
def
_sequence_parallel_context
(
self
):
ctx
=
get_forward_context
()
ctx
=
get_forward_context
()
...
@@ -558,18 +501,16 @@ class DefaultMoERunner(MoERunner):
...
@@ -558,18 +501,16 @@ class DefaultMoERunner(MoERunner):
return
final_shared_hidden_states
,
final_fused_hidden_states
return
final_shared_hidden_states
,
final_fused_hidden_states
def
_maybe_
gate
(
def
_maybe_
sync_shared_experts_stream
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
router_logits
:
torch
.
Tensor
,
):
)
->
torch
.
Tensor
:
# If router/gate provided, then apply it here.
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
# separate cuda stream)
if
self
.
gate
is
not
None
:
if
self
.
shared_experts
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
self
.
shared_experts
.
maybe_sync_shared_experts_stream
(
shared_experts_input
)
return
router_logits
@
property
@
property
def
do_naive_dispatch_combine
(
self
)
->
bool
:
def
do_naive_dispatch_combine
(
self
)
->
bool
:
...
@@ -624,7 +565,6 @@ class DefaultMoERunner(MoERunner):
...
@@ -624,7 +565,6 @@ class DefaultMoERunner(MoERunner):
hidden_states
,
hidden_states
,
dim
=
0
,
dim
=
0
,
)
)
# need RS for shared_output?
if
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
assert
shared_output
is
not
None
...
@@ -637,30 +577,86 @@ class DefaultMoERunner(MoERunner):
...
@@ -637,30 +577,86 @@ class DefaultMoERunner(MoERunner):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For latent MoE: save ORIGINAL hidden_states before transform
"""Invoke the fused moe layer.
# (shared_experts need original dimension, routed experts use transformed)
if
self
.
shared_experts
is
not
None
:
Input:
original_hidden_states
=
hidden_states
- hidden_states
else
:
- router_logits
original_hidden_states
=
None
Output:
- The new hidden_states.
or
- A tuple of (shared experts output, new hidden_states).
Calling sequence
- forward
- self.forward_entry (_moe_forward or _moe_forward_shared custom op)
- forward_dispatch
- forward_impl (_forward_impl or _forward_impl_chunked)
Note: The existence of _moe_forward and _moe_forward_shared custom ops are due
to the following reasons:
1. the chunking loop in _forward_impl_chunked cannot be compiled by
torch.compile
2. pytorch cannot handle union types in custom op signatures so _moe_forward
and _moe_forward_shared must be split.
If _forward_impl_chunked can be implemented via torch.scan we can potentially
get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence
into the 'forward' method.
"""
# Apply transform for routed experts (e.g., latent projection for latent MoE)
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
hidden_states
,
shared_experts_input
=
self
.
apply_routed_input_transform
(
hidden_states
)
hidden_states
,
og_hidden_dims
=
self
.
_maybe_pad_hidden_states
(
hidden_states
,
og_hidden_dims
=
self
.
_maybe_pad_hidden_states
(
original_hidden_states
,
shared_experts_input
,
hidden_states
,
hidden_states
,
)
)
fused_output
=
self
.
moe_
forward
(
fused_output
=
self
.
forward
_entry
(
hidden_states
,
hidden_states
,
router_logits
,
router_logits
,
original_hidden_states
,
shared_experts_input
,
self
.
_encode_layer_name
(),
self
.
_encode_layer_name
(),
)
)
return
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
return
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
def
forward_dispatch
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
# Sync aux and main stream for shared expert multi-stream overlap.
self
.
_maybe_sync_shared_experts_stream
(
shared_experts_input
)
# If the Runner holds the gate, apply it after the stream sync,
# so it can run overlapped with the
# NOTE: in future PR, MoE runner will always hold the gate.
if
self
.
gate
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
self
.
_maybe_apply_shared_experts
(
shared_experts_input
,
SharedExpertsOrder
.
EXTERNAL
,
)
with
self
.
_sequence_parallel_context
():
return
self
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
def
_slice_and_copy_input
(
def
_slice_and_copy_input
(
self
,
self
,
out_slice
:
torch
.
Tensor
,
out_slice
:
torch
.
Tensor
,
...
@@ -681,17 +677,13 @@ class DefaultMoERunner(MoERunner):
...
@@ -681,17 +677,13 @@ class DefaultMoERunner(MoERunner):
out_slice
.
copy_
(
orig_slice
,
non_blocking
=
True
)
out_slice
.
copy_
(
orig_slice
,
non_blocking
=
True
)
return
out_slice
return
out_slice
def
forward_impl_chunked
(
def
_
forward_impl_chunked
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
shared_
experts_
input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Gate overlap not supported when chunking is enabled. Run the
# gate first.
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
final_shared_hidden_states
,
final_fused_hidden_states
=
(
final_shared_hidden_states
,
final_fused_hidden_states
=
(
self
.
_allocate_dp_chunking_outputs
(
hidden_states
,
router_logits
)
self
.
_allocate_dp_chunking_outputs
(
hidden_states
,
router_logits
)
)
)
...
@@ -737,9 +729,9 @@ class DefaultMoERunner(MoERunner):
...
@@ -737,9 +729,9 @@ class DefaultMoERunner(MoERunner):
chunk_end
,
chunk_end
,
)
)
shared_input_chunk
=
(
shared_
experts_
input_chunk
=
(
shared_input
[
chunk_start
:
chunk_end
,
:]
shared_
experts_
input
[
chunk_start
:
chunk_end
,
:]
if
shared_input
is
not
None
if
shared_
experts_
input
is
not
None
else
None
else
None
)
)
...
@@ -747,7 +739,7 @@ class DefaultMoERunner(MoERunner):
...
@@ -747,7 +739,7 @@ class DefaultMoERunner(MoERunner):
layer
=
layer
,
layer
=
layer
,
hidden_states
=
hidden_states_chunk
,
hidden_states
=
hidden_states_chunk
,
router_logits
=
router_logits_chunk
,
router_logits
=
router_logits_chunk
,
shared_input
=
shared_input_chunk
,
shared_
experts_
input
=
shared_
experts_
input_chunk
,
)
)
# Store outputs
# Store outputs
...
@@ -769,40 +761,13 @@ class DefaultMoERunner(MoERunner):
...
@@ -769,40 +761,13 @@ class DefaultMoERunner(MoERunner):
assert
final_shared_hidden_states
is
not
None
assert
final_shared_hidden_states
is
not
None
return
(
final_shared_hidden_states
,
final_fused_hidden_states
)
return
(
final_shared_hidden_states
,
final_fused_hidden_states
)
def
forward_impl
(
def
_
forward_impl
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
shared_
experts_
input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
use_shared_experts_stream
=
(
current_platform
.
is_cuda
()
and
self
.
has_separate_shared_experts
and
not
self
.
use_dp_chunking
and
self
.
shared_experts_stream
is
not
None
and
(
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
# Check if we need to run shared experts before matrix multiply because
# matrix multiply may modify the hidden_states.
run_shared_experts_before
=
(
self
.
has_separate_shared_experts
and
not
self
.
use_shared_experts_stream
)
# The shared experts stream must be set up before calling the gate so they
# can be overlapped.
if
not
run_shared_experts_before
:
self
.
_maybe_setup_shared_experts_stream
(
hidden_states
,
shared_input
,
)
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
# TODO(bnell): parts of the dispatch/combine steps will go away once
# TODO(bnell): parts of the dispatch/combine steps will go away once
# #32567 lands and the remaining kernels are made MKs. The PCP
# #32567 lands and the remaining kernels are made MKs. The PCP
# code will probably remain
# code will probably remain
...
@@ -816,8 +781,7 @@ class DefaultMoERunner(MoERunner):
...
@@ -816,8 +781,7 @@ class DefaultMoERunner(MoERunner):
layer
=
layer
,
layer
=
layer
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
shared_input
=
shared_input
,
shared_experts_input
=
shared_experts_input
,
run_shared_experts_before
=
run_shared_experts_before
,
)
)
return
self
.
_maybe_combine
(
return
self
.
_maybe_combine
(
...
...
vllm/model_executor/layers/fused_moe/runner/moe_runner.py
View file @
7cf56a59
...
@@ -32,3 +32,7 @@ class MoERunner(ABC):
...
@@ -32,3 +32,7 @@ class MoERunner(ABC):
final_hidden_states
:
torch
.
Tensor
,
final_hidden_states
:
torch
.
Tensor
,
):
):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
is_internal_router
(
self
)
->
bool
:
raise
NotImplementedError
vllm/model_executor/layers/fused_moe/runner/shared_experts.py
0 → 100644
View file @
7cf56a59
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
IntEnum
import
torch
import
vllm.envs
as
envs
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
(
aux_stream
,
current_stream
,
)
from
vllm.v1.worker.ubatching
import
(
dbo_current_ubatch_id
,
)
logger
=
init_logger
(
__name__
)
class
SharedExpertsOrder
(
IntEnum
):
# No shared experts.
NONE
=
(
0
,)
# Get rid of this one? combine with BEFORE?
# Note: this might be important for torch.compile reasons. Can
# get rid of it after _moe_forward is undone.
EXTERNAL
=
(
1
,)
# No overlap - defensively called before MK.
NO_OVERLAP
=
(
2
,)
# Overlapped with dispatch/combine in DP/EP - called by the MK.
MK_INTERNAL_OVERLAPPED
=
(
3
,)
# Overlapped with the gate, router, experts in aux stream.
MULTI_STREAM_OVERLAPPED
=
(
4
,)
class
SharedExperts
:
def
__init__
(
self
,
layer
:
torch
.
nn
.
Module
,
moe_config
:
FusedMoEConfig
,
quant_method
:
QuantizeMethodBase
,
reduce_results
:
bool
,
enable_dbo
:
bool
,
):
from
vllm.model_executor.layers.fused_moe.fused_moe_method_base
import
(
FusedMoEMethodBase
,
)
# quant_method must be a FusedMoEMethodBase but we can't use the type
# due to circular imports.
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
# The SharedExperts need to handle DBO since they can be called from
# an MK's finalize method. We keep a list of outputs indexed by current
# DBO ubatch id to handle this case. If DBO is not enabled, the
# index is always 0 and the second output list element is ignored.
self
.
enable_dbo
=
enable_dbo
self
.
_output
:
list
[
torch
.
Tensor
|
None
]
=
[
None
,
None
]
self
.
_layer
=
layer
self
.
_moe_config
=
moe_config
self
.
_quant_method
=
quant_method
self
.
_reduce_results
=
reduce_results
self
.
_use_dp_chunking
=
moe_config
.
moe_parallel_config
.
use_dp_chunking
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if
envs
.
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
logger
.
debug_once
(
"Disabling MoE shared_experts cuda stream"
,
scope
=
"local"
)
self
.
_stream
=
None
else
:
# TODO(rob): enable shared expert overlap with non-cuda-alike.
# aux_stream() returns None on non-cuda-alike platforms.
self
.
_stream
=
aux_stream
()
if
self
.
_stream
is
not
None
:
logger
.
debug_once
(
"Enabled separate cuda stream for MoE shared_experts"
,
scope
=
"local"
)
@
property
def
_has_external_experts
(
self
)
->
bool
:
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain
backend
=
self
.
_moe_config
.
moe_parallel_config
.
all2all_backend
return
not
(
(
self
.
_moe_config
.
moe_parallel_config
.
enable_eplb
and
backend
!=
"allgather_reducescatter"
)
or
self
.
_moe_config
.
moe_parallel_config
.
use_fi_nvl_two_sided_kernels
)
def
_determine_shared_experts_order
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
SharedExpertsOrder
:
if
self
.
_has_external_experts
and
not
self
.
_use_dp_chunking
:
return
SharedExpertsOrder
.
EXTERNAL
if
self
.
_quant_method
.
mk_owns_shared_expert
:
return
SharedExpertsOrder
.
MK_INTERNAL_OVERLAPPED
should_run_shared_in_aux_stream
=
(
current_platform
.
is_cuda
()
and
not
self
.
_use_dp_chunking
and
self
.
_stream
is
not
None
and
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
if
should_run_shared_in_aux_stream
:
return
SharedExpertsOrder
.
MULTI_STREAM_OVERLAPPED
else
:
return
SharedExpertsOrder
.
NO_OVERLAP
def
maybe_sync_shared_experts_stream
(
self
,
shared_experts_input
:
torch
.
Tensor
,
):
experts_order
=
self
.
_determine_shared_experts_order
(
shared_experts_input
)
if
experts_order
==
SharedExpertsOrder
.
MULTI_STREAM_OVERLAPPED
:
assert
self
.
_stream
is
not
None
assert
self
.
_moe_config
.
disable_inplace
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
shared_experts_input
.
record_stream
(
self
.
_stream
)
# Mark sync start point for the aux stream since we will
# run in parallel with router/gate.
self
.
_stream
.
wait_stream
(
current_stream
())
def
_run_in_aux_stream
(
self
,
shared_experts_input
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# TODO: assert that maybe_sync_shared_experts_stream has been called.
# Run shared experts in parallel on a separate stream.
with
torch
.
cuda
.
stream
(
self
.
_stream
):
output
=
self
.
_layer
(
shared_experts_input
)
current_stream
().
wait_stream
(
self
.
_stream
)
return
output
def
_maybe_reduce_shared_out
(
self
,
shared_out
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
if
(
self
.
_reduce_results
and
self
.
_quant_method
.
moe_kernel
is
not
None
and
self
.
_quant_method
.
moe_kernel
.
output_is_reduced
()
and
get_tensor_model_parallel_world_size
()
>
1
):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
return
shared_out
@
property
def
_output_idx
(
self
)
->
int
:
return
dbo_current_ubatch_id
()
if
self
.
enable_dbo
else
0
@
property
def
output
(
self
)
->
torch
.
Tensor
:
assert
self
.
_output
[
self
.
_output_idx
]
is
not
None
output
=
self
.
_output
[
self
.
_output_idx
]
self
.
_output
[
self
.
_output_idx
]
=
None
return
output
def
apply
(
self
,
shared_experts_input
:
torch
.
Tensor
,
order
:
SharedExpertsOrder
,
):
experts_order
=
self
.
_determine_shared_experts_order
(
shared_experts_input
)
if
order
!=
experts_order
:
return
None
assert
self
.
_output
[
self
.
_output_idx
]
is
None
if
order
==
SharedExpertsOrder
.
MULTI_STREAM_OVERLAPPED
:
self
.
_output
[
self
.
_output_idx
]
=
self
.
_run_in_aux_stream
(
shared_experts_input
)
else
:
self
.
_output
[
self
.
_output_idx
]
=
self
.
_layer
(
shared_experts_input
)
if
order
==
SharedExpertsOrder
.
EXTERNAL
:
# TODO: figure out how to combine this with maybe_reduce_output?
# or get rid of it completely.
assert
self
.
_output
[
self
.
_output_idx
]
is
not
None
self
.
_output
[
self
.
_output_idx
]
=
self
.
_maybe_reduce_shared_out
(
self
.
_output
[
self
.
_output_idx
]
)
assert
self
.
_output
[
self
.
_output_idx
]
is
not
None
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment