Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9279c59a
Unverified
Commit
9279c59a
authored
Mar 19, 2026
by
bnellnm
Committed by
GitHub
Mar 19, 2026
Browse files
[MoE Refactor] DefaultMoERunner simplifcation (#33049)
Signed-off-by:
Bill Nell
<
bnell@redhat.com
>
parent
74540961
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
379 additions
and
295 deletions
+379
-295
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-0
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
...el_executor/layers/fused_moe/runner/default_moe_runner.py
+377
-295
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
9279c59a
...
...
@@ -504,6 +504,8 @@ class FusedMoE(CustomOp):
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
self
.
activation
=
MoEActivation
.
from_str
(
activation
)
# TODO(bnell): we should not have to create a router if the kernel is
# monolithic.
self
.
router
=
create_fused_moe_router
(
top_k
=
top_k
,
global_num_experts
=
self
.
global_num_experts
,
...
...
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
View file @
9279c59a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
...
...
@@ -82,9 +83,22 @@ def _moe_forward(
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
)
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
def
_moe_forward_fake
(
...
...
@@ -105,9 +119,22 @@ def _moe_forward_shared(
layer
=
get_layer_from_name
(
_resolve_layer_name
(
layer_name
))
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
)
runner
=
layer
.
runner
with
runner
.
_sequence_parallel_context
():
if
runner
.
use_dp_chunking
:
return
runner
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
else
:
return
runner
.
forward_impl
(
layer
,
hidden_states
,
router_logits
,
shared_experts_input
,
)
def
_moe_forward_shared_fake
(
...
...
@@ -191,10 +218,17 @@ class DefaultMoERunner(MoERunner):
self
.
reduce_results
=
reduce_results
self
.
enable_dbo
=
enable_dbo
# Chunked all2all staging tensor
# TODO(bnell) rename these?
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
self
.
_maybe_init_dp_chunking
()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
self
.
use_shared_experts_stream
=
False
if
envs
.
VLLM_DISABLE_SHARED_EXPERTS_STREAM
:
logger
.
debug_once
(
"Disabling MoE shared_experts cuda stream"
,
scope
=
"local"
)
self
.
shared_experts_stream
=
None
...
...
@@ -210,23 +244,20 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self
.
layer_name
=
layer
.
layer_name
self
.
moe_forward
=
self
.
_select_forward
(
layer
)
def
_select_forward
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Callable
:
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
if
self
.
shared_experts
is
None
:
self
.
moe_forward
=
_moe_forward
else
:
self
.
moe_forward
=
_moe_forward_shared
else
:
if
self
.
shared_experts
is
None
:
self
.
moe_forward
=
torch
.
ops
.
vllm
.
moe_forward
else
:
self
.
moe_forward
=
torch
.
ops
.
vllm
.
moe_forward_shared
return
_moe_forward
if
self
.
shared_experts
is
None
else
_moe_forward_shared
# Chunked all2all staging tensor
self
.
batched_hidden_states
:
torch
.
Tensor
|
None
=
None
self
.
batched_router_logits
:
torch
.
Tensor
|
None
=
None
return
(
torch
.
ops
.
vllm
.
moe_forward
if
self
.
shared_experts
is
None
else
torch
.
ops
.
vllm
.
moe_forward_shared
)
@
property
def
use_dp_chunking
(
self
)
->
bool
:
...
...
@@ -241,22 +272,8 @@ class DefaultMoERunner(MoERunner):
self
,
hidden_states
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
has_separate_shared_experts
:
bool
,
use_chunked_impl
:
bool
,
)
->
tuple
[
bool
,
torch
.
Tensor
|
None
]:
use_shared_experts_stream
=
(
current_platform
.
is_cuda
()
and
has_separate_shared_experts
and
not
use_chunked_impl
and
self
.
shared_experts_stream
is
not
None
and
(
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
shared_experts_input
:
torch
.
Tensor
|
None
=
None
if
use_shared_experts_stream
:
):
if
self
.
use_shared_experts_stream
:
assert
self
.
shared_experts_stream
is
not
None
assert
self
.
moe_config
.
disable_inplace
...
...
@@ -278,12 +295,11 @@ class DefaultMoERunner(MoERunner):
assert
self
.
shared_experts_stream
is
not
None
self
.
shared_experts_stream
.
wait_stream
(
current_stream
())
return
use_shared_experts_stream
,
shared_experts_input
def
ensure_dp_chunking_init
(
self
):
if
not
self
.
use_dp_chunking
or
self
.
batched_hidden_states
is
not
None
:
def
_maybe_init_dp_chunking
(
self
):
if
not
self
.
use_dp_chunking
:
return
assert
self
.
batched_hidden_states
is
None
states_shape
:
tuple
[
int
,
...]
logits_shape
:
tuple
[
int
,
...]
...
...
@@ -309,6 +325,38 @@ class DefaultMoERunner(MoERunner):
device
=
device
,
)
@
property
def
has_separate_shared_experts
(
self
)
->
bool
:
return
(
not
self
.
quant_method
.
mk_owns_shared_expert
and
self
.
shared_experts
is
not
None
)
def
_apply_shared_experts
(
self
,
hidden_states
:
torch
.
Tensor
,
allow_streaming
:
bool
=
False
,
)
->
torch
.
Tensor
|
None
:
shared_output
:
torch
.
Tensor
|
None
=
None
if
self
.
has_separate_shared_experts
:
assert
self
.
shared_experts
is
not
None
if
self
.
use_shared_experts_stream
and
allow_streaming
:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with
torch
.
cuda
.
stream
(
self
.
shared_experts_stream
):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output
=
self
.
shared_experts
(
hidden_states
)
current_stream
().
wait_stream
(
self
.
shared_experts_stream
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
return
shared_output
def
must_reduce_shared_expert_outputs
(
self
)
->
bool
:
"""
The shared_experts are typically computed using the RowParallelLinear
...
...
@@ -322,7 +370,6 @@ class DefaultMoERunner(MoERunner):
Therefore it is required that we reduce the shared_experts output
early.
"""
assert
self
.
quant_method
is
not
None
return
(
self
.
quant_method
.
moe_kernel
is
not
None
and
self
.
quant_method
.
moe_kernel
.
output_is_reduced
()
...
...
@@ -357,7 +404,7 @@ class DefaultMoERunner(MoERunner):
return
result
return
hidden_states
def
_reduce_output
(
def
_maybe
_reduce_output
(
self
,
states
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
trunc_sizes
:
list
[
int
],
...
...
@@ -397,23 +444,16 @@ class DefaultMoERunner(MoERunner):
return
"from_forward_context"
return
self
.
layer_name
def
forward
(
def
_maybe_pad_hidden_states
(
self
,
original_hidden_states
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if
self
.
shared_experts
is
not
None
:
original_hidden_states
=
hidden_states
original_hidden_dim
=
hidden_states
.
shape
[
-
1
]
else
:
original_hidden_states
=
None
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
# This is the dimension after transform (for routed expert output slicing)
)
->
tuple
[
torch
.
Tensor
,
list
[
int
]]:
original_hidden_dim
=
(
original_hidden_states
.
shape
[
-
1
]
if
original_hidden_states
is
not
None
else
0
)
transformed_hidden_dim
=
hidden_states
.
shape
[
-
1
]
if
(
not
self
.
quant_method
.
skip_forward_padding
...
...
@@ -426,134 +466,235 @@ class DefaultMoERunner(MoERunner):
value
=
0.0
,
)
fused_output
=
self
.
moe_forward
(
hidden_states
,
router_logits
,
original_hidden_states
,
self
.
_encode_layer_name
(),
)
if
self
.
shared_experts
is
not
None
:
orig_hidden_dims
=
[
original_hidden_dim
,
transformed_hidden_dim
]
else
:
orig_hidden_dims
=
[
transformed_hidden_dim
]
return
self
.
_reduce_output
(
fused_output
,
orig_hidden_dims
)
return
hidden_states
,
orig_hidden_dims
def
forward_impl_chunke
d
(
def
_apply_quant_metho
d
(
self
,
layer
:
torch
.
nn
.
Module
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
,
full_shared_input
:
torch
.
Tensor
|
None
,
has_separate_shared_experts
:
bool
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
run_shared_experts_before
:
bool
=
True
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
shared_input
=
shared_input
if
shared_input
is
not
None
else
hidden_states
shared_output
:
torch
.
Tensor
|
None
=
None
# Run this before quant_method to avoid inplace issues.
if
run_shared_experts_before
:
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
False
)
if
self
.
quant_method
.
is_monolithic
:
result
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
result
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
isinstance
(
result
,
tuple
):
assert
shared_output
is
None
shared_output
,
hidden_states
=
result
else
:
hidden_states
=
result
if
not
run_shared_experts_before
and
self
.
has_separate_shared_experts
:
assert
shared_output
is
None
shared_output
=
self
.
_apply_shared_experts
(
shared_input
,
True
)
return
shared_output
,
hidden_states
def
_sequence_parallel_context
(
self
):
ctx
=
get_forward_context
()
return
(
ctx
.
dp_metadata
.
sp_local_sizes
(
self
.
moe_config
.
sp_size
)
if
ctx
.
dp_metadata
else
nullcontext
()
)
def
_allocate_dp_chunking_outputs
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
assert
self
.
use_dp_chunking
# Assert the inputs are of the proper type and shape.
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_router_logits
is
not
None
assert
self
.
batched_hidden_states
.
dtype
==
full_hidden_states
.
dtype
,
(
f
"
{
self
.
batched_hidden_states
.
dtype
}
==
{
full_hidden_states
.
dtype
}
"
assert
self
.
batched_hidden_states
.
dtype
==
hidden_states
.
dtype
,
(
f
"
{
self
.
batched_hidden_states
.
dtype
}
==
{
hidden_states
.
dtype
}
"
)
assert
self
.
batched_router_logits
.
dtype
==
full_
router_logits
.
dtype
,
(
f
"
{
self
.
batched_router_logits
.
dtype
}
==
{
full_
router_logits
.
dtype
}
"
assert
self
.
batched_router_logits
.
dtype
==
router_logits
.
dtype
,
(
f
"
{
self
.
batched_router_logits
.
dtype
}
==
{
router_logits
.
dtype
}
"
)
# Check size compatibility.
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
full_hidden_states
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
full_router_logits
.
size
(
-
1
)
# TODO(bnell): Fix shared_expert_inputs w/chunking.
# assert shared_input is None, (
# "Routed input transform is not currently supported with DP chunking."
# )
# Check size compatibility.
assert
self
.
batched_hidden_states
.
size
(
-
1
)
==
hidden_states
.
size
(
-
1
)
assert
self
.
batched_router_logits
.
size
(
-
1
)
==
router_logits
.
size
(
-
1
)
f
ul
l_fused_
final_
hidden_states
=
torch
.
empty_like
(
full_
hidden_states
)
f
ina
l_fused_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
self
.
shared_experts
is
not
None
:
full_shared_final_hidden_states
=
torch
.
empty_like
(
full_hidden_states
)
def
process_chunk
(
chunk_start
,
chunk_end
,
skip_result_store
=
False
):
chunk_size
=
chunk_end
-
chunk_start
hidden_states
=
full_hidden_states
[
chunk_start
:
chunk_end
,
:]
router_logits
=
full_router_logits
[
chunk_start
:
chunk_end
,
:]
shared_input
=
(
full_shared_input
[
chunk_start
:
chunk_end
,
:]
if
full_shared_input
is
not
None
else
None
)
final_shared_hidden_states
=
torch
.
empty_like
(
hidden_states
)
else
:
final_shared_hidden_states
=
None
assert
self
.
batched_hidden_states
is
not
None
assert
self
.
batched_router_logits
is
not
None
# This is only true when DBO has been enabled in the config.
# Both tensors will have an outer dimension for the ubatch id
if
self
.
batched_hidden_states
.
dim
()
==
3
:
assert
self
.
batched_router_logits
.
dim
()
==
3
batch_buffer_idx
=
dbo_current_ubatch_id
()
batched_hidden_states
=
self
.
batched_hidden_states
[
batch_buffer_idx
,
:]
batched_router_logits
=
self
.
batched_router_logits
[
batch_buffer_idx
,
:]
else
:
batched_hidden_states
=
self
.
batched_hidden_states
batched_router_logits
=
self
.
batched_router_logits
return
final_shared_hidden_states
,
final_fused_hidden_states
def
_maybe_gate
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if
self
.
gate
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
return
router_logits
@
property
def
do_naive_dispatch_combine
(
self
)
->
bool
:
return
(
self
.
moe_config
.
dp_size
>
1
and
not
self
.
quant_method
.
supports_internal_mk
)
assert
(
batched_hidden_states
.
size
(
0
)
# type: ignore
>=
chunk_size
def
_maybe_dispatch
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if
self
.
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch_router_logits
(
hidden_states
,
router_logits
,
self
.
moe_config
.
is_sequence_parallel
,
)
assert
(
batched_router_logits
.
size
(
0
)
# type: ignore
>=
chunk_size
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstraction to better support PCP.
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
all_gather
(
hidden_states
,
dim
=
0
,
)
staged_hidden_states
=
batched_hidden_states
[:
chunk_size
,
:]
# type: ignore
staged_router_logits
=
batched_router_logits
[:
chunk_size
,
:]
# type: ignore
staged_hidden_states
.
copy_
(
hidden_states
,
non_blocking
=
True
)
staged_router_logits
.
copy_
(
router_logits
,
non_blocking
=
True
)
router_logits
=
get_pcp_group
().
all_gather
(
router_logits
,
dim
=
0
,
)
return
hidden_states
,
router_logits
shared_input
=
(
shared_input
if
shared_input
is
not
None
else
staged_hidden_states
def
_maybe_combine
(
self
,
shared_output
:
torch
.
Tensor
|
None
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
self
.
do_naive_dispatch_combine
:
hidden_states
=
get_ep_group
().
combine
(
hidden_states
,
self
.
moe_config
.
is_sequence_parallel
)
# Matrix multiply.
if
self
.
quant_method
.
is_monolithic
:
assert
has_separate_shared_experts
or
self
.
shared_experts
is
None
final_hidden_states
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
staged_hidden_states
,
router_logits
=
staged_router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
staged_hidden_states
,
router_logits
=
staged_router_logits
,
)
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
reduce_scatter
(
hidden_states
,
dim
=
0
,
)
# need RS for shared_output?
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
staged_hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
return
shared_output
,
hidden_states
else
:
return
hidden_states
if
has_separate_shared_experts
:
assert
not
isinstance
(
final_hidden_states
,
tuple
)
assert
self
.
shared_experts
is
not
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
if
self
.
shared_experts
is
not
None
:
original_hidden_states
=
hidden_states
else
:
original_hidden_states
=
None
shared_output
=
self
.
shared_experts
(
shared_input
)
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
final
_hidden_states
=
(
shared_output
,
final_
hidden_states
,
)
hidden_states
,
og_hidden_dims
=
self
.
_maybe_pad
_hidden_states
(
original_hidden_states
,
hidden_states
,
)
if
not
skip_result_store
:
if
self
.
shared_experts
is
None
:
full_fused_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
,
non_blocking
=
True
)
else
:
full_shared_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
[
0
],
non_blocking
=
True
)
full_fused_final_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
final_hidden_states
[
1
],
non_blocking
=
True
)
fused_output
=
self
.
moe_forward
(
hidden_states
,
router_logits
,
original_hidden_states
,
self
.
_encode_layer_name
(),
)
return
self
.
_maybe_reduce_output
(
fused_output
,
og_hidden_dims
)
def
_slice_and_copy_input
(
self
,
out_slice
:
torch
.
Tensor
,
orig
:
torch
.
Tensor
|
None
,
start
:
int
,
end
:
int
,
)
->
torch
.
Tensor
:
assert
orig
is
not
None
slice_size
=
end
-
start
orig_slice
=
orig
[
start
:
end
,
:]
if
self
.
enable_dbo
:
assert
out_slice
.
dim
()
==
3
batch_buffer_idx
=
dbo_current_ubatch_id
()
out_slice
=
out_slice
[
batch_buffer_idx
,
:]
assert
out_slice
.
size
(
0
)
>=
slice_size
out_slice
=
out_slice
[:
slice_size
,
:]
out_slice
.
copy_
(
orig_slice
,
non_blocking
=
True
)
return
out_slice
def
forward_impl_chunked
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Gate overlap not supported when chunking is enabled. Run the
# gate first.
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
final_shared_hidden_states
,
final_fused_hidden_states
=
(
self
.
_allocate_dp_chunking_outputs
(
hidden_states
,
router_logits
)
)
ctx
=
get_forward_context
()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
...
...
@@ -567,7 +708,7 @@ class DefaultMoERunner(MoERunner):
max_tokens_across_dispatchers
,
self
.
moe_config
.
sp_size
)
num_tokens
=
full_
hidden_states
.
size
(
0
)
num_tokens
=
hidden_states
.
size
(
0
)
for
chunk_idx
,
chunk_start_
in
enumerate
(
range
(
0
,
max_tokens_across_dispatchers
,
moe_dp_chunk_size_per_rank
)
):
...
...
@@ -578,17 +719,55 @@ class DefaultMoERunner(MoERunner):
# clamp start and end
chunk_start
=
min
(
chunk_start
,
num_tokens
-
1
)
chunk_end
=
min
(
chunk_end
,
num_tokens
)
with
ctx
.
dp_metadata
.
chunked_sizes
(
chunk_sizes
=
ctx
.
dp_metadata
.
chunked_sizes
(
self
.
moe_config
.
sp_size
,
moe_dp_chunk_size_per_rank
,
chunk_idx
):
process_chunk
(
chunk_start
,
chunk_end
,
skip_result_store
=
chunk_start_
>=
num_tokens
)
with
chunk_sizes
:
hidden_states_chunk
=
self
.
_slice_and_copy_input
(
self
.
batched_hidden_states
,
hidden_states
,
chunk_start
,
chunk_end
,
)
router_logits_chunk
=
self
.
_slice_and_copy_input
(
self
.
batched_router_logits
,
router_logits
,
chunk_start
,
chunk_end
,
)
shared_input_chunk
=
(
shared_input
[
chunk_start
:
chunk_end
,
:]
if
shared_input
is
not
None
else
None
)
shared_output_chunk
,
hidden_states_chunk
=
self
.
_apply_quant_method
(
layer
=
layer
,
hidden_states
=
hidden_states_chunk
,
router_logits
=
router_logits_chunk
,
shared_input
=
shared_input_chunk
,
)
# Store outputs
# TODO(bnell): document when chunk_start >= num_tokens
if
chunk_start
<
num_tokens
:
final_fused_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
hidden_states_chunk
,
non_blocking
=
True
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output_chunk
is
not
None
assert
final_shared_hidden_states
is
not
None
final_shared_hidden_states
[
chunk_start
:
chunk_end
,
:].
copy_
(
shared_output_chunk
,
non_blocking
=
True
)
if
self
.
shared_experts
is
None
:
return
f
ul
l_fused_
final_
hidden_states
return
f
ina
l_fused_hidden_states
else
:
return
(
full_shared_final_hidden_states
,
full_fused_final_hidden_states
)
assert
final_shared_hidden_states
is
not
None
return
(
final_shared_hidden_states
,
final_fused_hidden_states
)
def
forward_impl
(
self
,
...
...
@@ -597,148 +776,51 @@ class DefaultMoERunner(MoERunner):
router_logits
:
torch
.
Tensor
,
shared_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
quant_method
is
not
None
self
.
ensure_dp_chunking_init
()
has_separate_shared_experts
=
(
not
self
.
quant_method
.
mk_owns_shared_expert
and
self
.
shared_experts
is
not
None
self
.
use_shared_experts_stream
=
(
current_platform
.
is_cuda
()
and
self
.
has_separate_shared_experts
and
not
self
.
use_dp_chunking
and
self
.
shared_experts_stream
is
not
None
and
(
hidden_states
.
shape
[
0
]
<=
envs
.
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
use_chunked_impl
=
self
.
use_dp_chunking
# Check if we need to run shared experts before matrix multiply because
# matrix multiply may modify the hidden_states.
run_shared_experts_before
=
(
self
.
has_separate_shared_experts
and
not
self
.
use_shared_experts_stream
)
use_shared_experts_stream
,
shared_experts_input
=
(
# The shared experts stream must be set up before calling the gate so they
# can be overlapped.
if
not
run_shared_experts_before
:
self
.
_maybe_setup_shared_experts_stream
(
hidden_states
,
shared_input
,
has_separate_shared_experts
,
use_chunked_impl
,
)
)
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if
self
.
gate
is
not
None
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
use_chunked_impl
:
return
self
.
forward_impl_chunked
(
layer
,
hidden_states
,
router_logits
,
shared_input
,
has_separate_shared_experts
,
)
router_logits
=
self
.
_maybe_gate
(
hidden_states
,
router_logits
)
# NOTE(rob): once we finish migrating all the quant methods to use
# MKs, we can remove the naive dispatch/combine path from here.
do_naive_dispatch_combine
=
(
self
.
moe_config
.
dp_size
>
1
and
not
self
.
quant_method
.
supports_internal_mk
# TODO(bnell): parts of the dispatch/combine steps will go away once
# #32567 lands and the remaining kernels are made MKs. The PCP
# code will probably remain
hidden_states
,
router_logits
=
self
.
_maybe_dispatch
(
layer
,
hidden_states
,
router_logits
,
)
ctx
=
get_forward_context
()
sp_ctx
=
(
ctx
.
dp_metadata
.
sp_local_sizes
(
self
.
moe_config
.
sp_size
)
if
ctx
.
dp_metadata
else
nullcontext
()
shared_output
,
hidden_states
=
self
.
_apply_quant_method
(
layer
=
layer
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
shared_input
=
shared_input
,
run_shared_experts_before
=
run_shared_experts_before
,
)
with
sp_ctx
:
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if
has_separate_shared_experts
and
not
use_shared_experts_stream
:
assert
self
.
shared_experts
is
not
None
shared_input
=
(
shared_input
if
shared_input
is
not
None
else
hidden_states
)
shared_output
=
self
.
shared_experts
(
shared_input
)
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
# router logits to all experts.
# NOTE: this will be removed once all kernels are migrated into the
# MoEKernel framework.
if
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch_router_logits
(
hidden_states
,
router_logits
,
self
.
moe_config
.
is_sequence_parallel
,
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if
self
.
moe_config
.
pcp_size
>
1
:
hidden_states
=
get_pcp_group
().
all_gather
(
hidden_states
,
dim
=
0
,
)
router_logits
=
get_pcp_group
().
all_gather
(
router_logits
,
dim
=
0
,
)
# Matrix multiply.
if
self
.
quant_method
.
is_monolithic
:
final_hidden_states
=
self
.
quant_method
.
apply_monolithic
(
layer
=
layer
,
x
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
topk_weights
,
topk_ids
=
self
.
router
.
select_experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
layer
,
x
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
shared_experts_input
=
shared_input
,
)
if
has_separate_shared_experts
:
assert
self
.
shared_experts
is
not
None
if
use_shared_experts_stream
:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with
torch
.
cuda
.
stream
(
self
.
shared_experts_stream
):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output
=
self
.
shared_experts
(
shared_experts_input
)
current_stream
().
wait_stream
(
self
.
shared_experts_stream
)
final_hidden_states
=
(
shared_output
,
final_hidden_states
,
)
def
combine_output
(
states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
do_naive_dispatch_combine
:
states
=
get_ep_group
().
combine
(
states
,
self
.
moe_config
.
is_sequence_parallel
)
if
self
.
moe_config
.
pcp_size
>
1
:
states
=
get_pcp_group
().
reduce_scatter
(
states
,
dim
=
0
,
)
return
states
if
self
.
shared_experts
is
not
None
:
return
(
final_hidden_states
[
0
],
combine_output
(
final_hidden_states
[
1
]),
)
else
:
return
combine_output
(
final_hidden_states
)
return
self
.
_maybe_combine
(
shared_output
,
hidden_states
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment