Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0aca8b8c
Unverified
Commit
0aca8b8c
authored
Feb 02, 2026
by
danielafrimi
Committed by
GitHub
Feb 02, 2026
Browse files
[MoE] Enable Shared/Routed Overlap For Latent MoE (Nemotron-H) (#32790)
Signed-off-by:
dafrimi
<
dafrimi@nvidia.com
>
parent
9eb58f8c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
303 additions
and
58 deletions
+303
-58
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
+162
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+89
-16
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+22
-0
vllm/model_executor/models/nemotron_h.py
vllm/model_executor/models/nemotron_h.py
+30
-42
No files found.
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
0 → 100644
View file @
0aca8b8c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for SharedFusedMoE with routed_input_transform.
Verifies that applying routed_input_transform inside SharedFusedMoE
produces the same results as applying the transform manually outside.
"""
import
pytest
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
class
SimpleLinear
(
nn
.
Module
):
"""A simple linear transform mimicking latent projection in latent MoE."""
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
dtype
:
torch
.
dtype
):
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
randn
(
out_features
,
in_features
,
device
=
"cuda"
,
dtype
=
dtype
)
/
10
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
nn
.
functional
.
linear
(
x
,
self
.
weight
)
class
SimpleSharedExperts
(
nn
.
Module
):
"""A simple 2-layer MLP mimicking shared experts."""
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
dtype
:
torch
.
dtype
):
super
().
__init__
()
self
.
up
=
nn
.
Linear
(
hidden_size
,
intermediate_size
*
2
,
bias
=
False
,
device
=
"cuda"
,
dtype
=
dtype
)
self
.
down
=
nn
.
Linear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
device
=
"cuda"
,
dtype
=
dtype
)
with
torch
.
no_grad
():
self
.
up
.
weight
.
div_
(
10
)
self
.
down
.
weight
.
div_
(
10
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gate_up
=
self
.
up
(
x
)
gate
,
up
=
gate_up
.
chunk
(
2
,
dim
=-
1
)
return
self
.
down
(
nn
.
functional
.
silu
(
gate
)
*
up
)
@pytest.fixture(autouse=True)
def setup_cuda():
    """Skip the test when CUDA is unavailable; otherwise default tensors to CUDA.

    Runs automatically for every test in this module (``autouse=True``).
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.set_default_device("cuda")
        return
    # pytest.skip raises, so nothing after it would run.
    pytest.skip("CUDA not available")
@pytest.mark.parametrize("num_tokens", [1, 32])
@pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_routed_input_transform_inside_vs_outside(
    num_tokens: int,
    hidden_size: int,
    latent_size: int,
    dtype: torch.dtype,
    dist_init,
    workspace_init,
):
    """Check that the in-layer routed transform matches applying it by hand.

    Method A (inside): SharedFusedMoE built with ``routed_input_transform``
    and fed the original hidden states.
    Method B (outside): the transform is applied manually and the result is
    fed to a SharedFusedMoE built without a transform.

    The routed outputs of A and B must agree, and the shared output of A must
    equal the shared experts applied directly to the original hidden states.
    """
    torch.manual_seed(42)

    num_experts = 8
    top_k = 2
    intermediate_size = hidden_size * 2

    vllm_config = VllmConfig()
    vllm_config.compilation_config.static_forward_context = {}

    shared_experts = SimpleSharedExperts(hidden_size, intermediate_size, dtype)
    routed_transform = SimpleLinear(hidden_size, latent_size, dtype)

    # Arguments shared by both layers; the routed experts operate on
    # latent_size-wide inputs in both setups.
    common_moe_kwargs = dict(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=latent_size,
        intermediate_size=intermediate_size,
        reduce_results=False,
        renormalize=True,
        params_dtype=dtype,
        tp_size=1,
        dp_size=1,
        pcp_size=1,
    )

    with set_current_vllm_config(vllm_config):
        # Method A: SharedFusedMoE WITH routed_input_transform
        moe_with_transform = SharedFusedMoE(
            shared_experts=shared_experts,
            routed_input_transform=routed_transform,
            prefix="moe_with_transform",
            **common_moe_kwargs,
        )

        # Method B: SharedFusedMoE WITHOUT routed_input_transform
        # Note: shared_experts=None because when the transform is done
        # outside, this layer only ever sees the latent_size-wide tensor,
        # while the shared experts here are built for hidden_size-wide input.
        moe_without_transform = SharedFusedMoE(
            shared_experts=None,
            routed_input_transform=None,
            prefix="moe_without_transform",
            **common_moe_kwargs,
        )

        # Give both layers identical routed-expert weights.
        with torch.no_grad():
            moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight)
            moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight)

        for layer in (moe_with_transform, moe_without_transform):
            layer.quant_method.process_weights_after_loading(layer)

        hidden_states = torch.randn(
            num_tokens, hidden_size, device="cuda", dtype=dtype
        )
        router_logits = torch.randn(
            num_tokens, num_experts, device="cuda", dtype=dtype
        )

        with set_forward_context(None, vllm_config, num_tokens=num_tokens):
            shared_out_A, routed_out_A = moe_with_transform(
                hidden_states, router_logits
            )

            transformed_hidden = routed_transform(hidden_states)
            # The shared output of method B is not compared.
            shared_out_B, routed_out_B = moe_without_transform(
                transformed_hidden, router_logits
            )

        torch.testing.assert_close(
            routed_out_A,
            routed_out_B,
            atol=1e-3,
            rtol=1e-3,
            msg="Routed output should match: transform inside vs outside",
        )

        expected_shared_out = shared_experts(hidden_states)
        torch.testing.assert_close(
            shared_out_A,
            expected_shared_out,
            atol=1e-3,
            rtol=1e-3,
        )
vllm/model_executor/layers/fused_moe/layer.py
View file @
0aca8b8c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Iterable
from
contextlib
import
nullcontext
from
collections.abc
import
Callable
,
Generator
,
Iterable
from
contextlib
import
contextmanager
,
nullcontext
from
enum
import
Enum
from
typing
import
Literal
,
cast
,
get_args
,
overload
...
...
@@ -351,6 +351,10 @@ class FusedMoE(CustomOp):
"Enabled separate cuda stream for MoE shared_experts"
,
scope
=
"local"
)
# For latent MoE: stores original hidden_states before routed_input_transform
# so shared_experts can use it for cloning (they need original dimension)
self
.
_shared_experts_input
:
torch
.
Tensor
|
None
=
None
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
...
...
@@ -664,6 +668,39 @@ class FusedMoE(CustomOp):
def
gate
(
self
)
->
torch
.
nn
.
Module
|
None
:
return
None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hook applied to hidden_states before they reach the routed experts.

    The base implementation is the identity: it returns ``hidden_states``
    unchanged. Subclasses (e.g., SharedFusedMoE) override it for latent MoE,
    where the input goes from [S, hidden_size] to [S, moe_latent_size].
    """
    # No transform at this level: hand the same tensor back.
    return hidden_states
@
contextmanager
def
_set_shared_experts_input
(
self
,
value
:
torch
.
Tensor
|
None
)
->
Generator
[
None
,
None
,
None
]:
"""Context manager to safely set/clear _shared_experts_input."""
self
.
_shared_experts_input
=
value
try
:
yield
finally
:
self
.
_shared_experts_input
=
None
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Pick the tensor the shared experts should consume.

    Returns ``self._shared_experts_input`` when it is set, otherwise the
    ``hidden_states`` argument. For latent MoE the stored tensor is the
    original [S, hidden_size] input rather than the transformed
    [S, latent_size] one used by the routed experts.
    """
    stored_input = self._shared_experts_input
    if stored_input is None:
        return hidden_states
    return stored_input
@property
def tp_size(self):
    """Return ``tp_size`` as stored on ``self.moe_parallel_config``."""
    parallel_config = self.moe_parallel_config
    return parallel_config.tp_size
...
...
@@ -855,9 +892,11 @@ class FusedMoE(CustomOp):
if
use_shared_experts_stream
:
assert
self
.
shared_experts_stream
is
not
None
shared_experts_input
=
self
.
_get_shared_experts_input
(
hidden_states
)
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone
=
hidden_states
.
clone
()
hidden_states_clone
=
shared_experts_input
.
clone
()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
...
...
@@ -1537,11 +1576,20 @@ class FusedMoE(CustomOp):
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
og_hidden_states
=
hidden_states
.
shape
[
-
1
]
if
self
.
hidden_size
!=
og_hidden_states
:
# For latent MoE: save ORIGINAL hidden_states before transform
# (shared_experts need original dimension, routed experts use transformed)
original_hidden_states
=
hidden_states
original_hidden_dim
=
hidden_states
.
shape
[
-
1
]
# Apply transform for routed experts (e.g., latent projection for latent MoE)
hidden_states
=
self
.
apply_routed_input_transform
(
hidden_states
)
# This is the dimension after transform (for routed expert output slicing)
transformed_hidden_dim
=
hidden_states
.
shape
[
-
1
]
if
self
.
hidden_size
!=
transformed_hidden_dim
:
hidden_states
=
F
.
pad
(
hidden_states
,
(
0
,
self
.
hidden_size
-
og
_hidden_
states
),
(
0
,
self
.
hidden_size
-
transformed
_hidden_
dim
),
mode
=
"constant"
,
value
=
0.0
,
)
...
...
@@ -1576,22 +1624,31 @@ class FusedMoE(CustomOp):
fused_output
=
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
encode_layer_name
()
)
return
reduce_output
(
fused_output
)[...,
:
og
_hidden_
states
]
return
reduce_output
(
fused_output
)[...,
:
transformed
_hidden_
dim
]
else
:
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
shared_output
,
fused_output
=
self
.
forward_impl
(
hidden_states
,
router_logits
)
with
self
.
_set_shared_experts_input
(
original_hidden_states
):
shared_output
,
fused_output
=
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
# Custom op handles setting/clearing _shared_experts_input internally
# We pass original tensor for shared experts (not transformed)
shared_output
,
fused_output
=
torch
.
ops
.
vllm
.
moe_forward_shared
(
hidden_states
,
router_logits
,
encode_layer_name
()
hidden_states
,
router_logits
,
encode_layer_name
(),
original_hidden_states
,
)
# shared_output uses original dimension (before transform)
# fused_output uses transformed dimension (after transform)
return
(
reduce_output
(
shared_output
)[...,
:
o
g
_hidden_
states
],
reduce_output
(
fused_output
)[...,
:
og
_hidden_
states
],
reduce_output
(
shared_output
)[...,
:
o
riginal
_hidden_
dim
],
reduce_output
(
fused_output
)[...,
:
transformed
_hidden_
dim
],
)
@
property
...
...
@@ -1831,7 +1888,8 @@ class FusedMoE(CustomOp):
# because matrix multiply maybe modify the hidden_states.
if
has_separate_shared_experts
and
not
use_shared_experts_stream
:
assert
self
.
shared_experts
is
not
None
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_input
=
self
.
_get_shared_experts_input
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
shared_input
)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
...
...
@@ -2021,19 +2079,34 @@ def moe_forward_shared(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_experts_input
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
=
get_layer_from_name
(
layer_name
)
assert
self
.
shared_experts
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
# Set here because torch.compile skips forward_native() setup code
# and calls this op directly. forward_impl() reads from this var.
with
self
.
_set_shared_experts_input
(
shared_experts_input
):
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
def
moe_forward_shared_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_experts_input
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shared_out
=
torch
.
empty_like
(
hidden_states
)
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out
=
torch
.
empty_like
(
hidden_states
)
if
shared_experts_input
is
not
None
:
shared_out
=
torch
.
empty_like
(
shared_experts_input
)
else
:
shared_out
=
torch
.
empty_like
(
hidden_states
)
return
shared_out
,
fused_out
...
...
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
0aca8b8c
...
...
@@ -23,10 +23,12 @@ class SharedFusedMoE(FusedMoE):
shared_experts
:
torch
.
nn
.
Module
|
None
,
gate
:
torch
.
nn
.
Module
|
None
=
None
,
use_overlapped
:
bool
=
True
,
routed_input_transform
:
torch
.
nn
.
Module
|
None
=
None
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
_shared_experts
=
shared_experts
self
.
_routed_input_transform
=
routed_input_transform
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
...
...
@@ -56,6 +58,26 @@ class SharedFusedMoE(FusedMoE):
def
is_internal_router
(
self
)
->
bool
:
return
self
.
gate
is
not
None
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run the optional routed-input transform (e.g., latent projection).

    Called by FusedMoE.forward_native. The original hidden_states is kept
    separately, so shared experts see [S, hidden_size] while routed experts
    see the transformed [S, moe_latent_size].

    When no transform is configured the input is returned unchanged.

    TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
    moved inside SharedFusedMoE to all-reduce on the smaller latent
    dimension.
    """
    transform = self._routed_input_transform
    if transform is None:
        return hidden_states

    transformed = transform(hidden_states)
    # ReplicatedLinear returns an (output, extra_bias) tuple; only the
    # output tensor is needed, so extra_bias is dropped.
    return transformed[0] if isinstance(transformed, tuple) else transformed
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/nemotron_h.py
View file @
0aca8b8c
...
...
@@ -188,10 +188,29 @@ class NemotronHMoE(nn.Module):
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
if
self
.
use_latent_moe
:
self
.
fc1_latent_proj
=
ReplicatedLinear
(
input_size
=
config
.
hidden_size
,
output_size
=
self
.
moe_hidden_size
,
bias
=
config
.
mlp_bias
,
quant_config
=
quant_config
,
disable_tp
=
self
.
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.fc1_latent_proj"
,
)
self
.
fc2_latent_proj
=
ReplicatedLinear
(
input_size
=
self
.
moe_hidden_size
,
output_size
=
config
.
hidden_size
,
bias
=
config
.
mlp_bias
,
quant_config
=
quant_config
,
disable_tp
=
self
.
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.fc2_latent_proj"
,
)
else
:
self
.
fc1_latent_proj
=
None
self
.
fc2_latent_proj
=
None
self
.
experts
=
SharedFusedMoE
(
# TODO: make it possible for shared experts to have
# different input in SharedFusedMoE
shared_experts
=
self
.
shared_experts
if
not
self
.
use_latent_moe
else
None
,
shared_experts
=
self
.
shared_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
self
.
moe_hidden_size
,
...
...
@@ -211,30 +230,9 @@ class NemotronHMoE(nn.Module):
num_redundant_experts
=
self
.
n_redundant_experts
,
is_sequence_parallel
=
self
.
is_sequence_parallel
,
router_logits_dtype
=
router_logits_dtype
,
routed_input_transform
=
self
.
fc1_latent_proj
,
)
if
self
.
use_latent_moe
:
self
.
fc1_latent_proj
=
ReplicatedLinear
(
input_size
=
config
.
hidden_size
,
output_size
=
self
.
moe_hidden_size
,
bias
=
config
.
mlp_bias
,
quant_config
=
quant_config
,
disable_tp
=
self
.
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.fc1_latent_proj"
,
)
self
.
fc2_latent_proj
=
ReplicatedLinear
(
input_size
=
self
.
moe_hidden_size
,
output_size
=
config
.
hidden_size
,
bias
=
config
.
mlp_bias
,
quant_config
=
quant_config
,
disable_tp
=
self
.
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.fc2_latent_proj"
,
)
else
:
self
.
fc1_latent_proj
=
None
self
.
fc2_latent_proj
=
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
@@ -244,38 +242,28 @@ class NemotronHMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
.
to
(
dtype
=
torch
.
float32
))
shared_output
=
None
if
self
.
use_latent_moe
:
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
hidden_states
,
_
=
self
.
fc1_latent_proj
(
hidden_states
)
fused_moe_out
=
self
.
experts
(
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)
# - routed_input_transform (fc1_latent_proj) for latent MoE
# - multistream parallelism between shared and routed experts
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
self
.
use_latent_moe
:
_
,
final_hidden_states
=
fused_moe_out
else
:
shared_output
,
final_hidden_states
=
fused_moe_out
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
*=
self
.
routed_scaling_factor
elif
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
shared_output
*=
1.0
/
self
.
routed_scaling_factor
# TODO: currently latent up_proj is done before all-reduce for simplicity.
# if and when shared experts will be part of SharedFusedMoE,
# we should do the up_proj after all-reduce,
# to have the all-reduce in the smaller latent dimension.
# TODO: See SharedFusedMoE.apply_routed_input_transform
# for bandwidth optimization
if
self
.
use_latent_moe
:
final_hidden_states
,
_
=
self
.
fc2_latent_proj
(
final_hidden_states
)
if
self
.
shared_experts
is
not
None
:
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
if
self
.
is_sequence_parallel
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment