Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3896e021
Unverified
Commit
3896e021
authored
Mar 31, 2026
by
SandishKumarHN
Committed by
GitHub
Mar 31, 2026
Browse files
[Bugfix] Fix FusedMoE weight loading with padded hidden dimensions (#37010)
Signed-off-by:
SandishKumarHN
<
sandish@fb.com
>
parent
b6e636c1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
374 additions
and
14 deletions
+374
-14
tests/kernels/moe/test_moe_weight_loading_padded.py
tests/kernels/moe/test_moe_weight_loading_padded.py
+292
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+82
-14
No files found.
tests/kernels/moe/test_moe_weight_loading_padded.py
0 → 100644
View file @
3896e021
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for FusedMoE weight loading with padded hidden dimensions.
When using DeepEP backends or NIXL EP with models like nemotron_h,
hidden_size may be rounded up (e.g., 2688 -> 3072) for backend requirements.
Weight parameters are created with the padded size, but checkpoint weights
have the original unpadded size. These tests verify that weight loading
correctly handles this mismatch.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
class
TestGetHiddenDim
:
"""Unit tests for _get_hidden_dim."""
def
test_2d_non_transposed_w2
(
self
):
# w2: shard_dim=1 (intermediate), hidden=0
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
2
)
==
0
def
test_2d_non_transposed_w13
(
self
):
# w1/w3: shard_dim=0 (intermediate), hidden=1
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
2
)
==
1
def
test_2d_transposed_w2
(
self
):
# transposed w2: shard_dim=0, hidden=1
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
2
)
==
1
def
test_2d_transposed_w13
(
self
):
# transposed w1/w3: shard_dim=1, hidden=0
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
2
)
==
0
def
test_3d_non_transposed_w2
(
self
):
# 3D w2: shard_dim=2, hidden=1
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
2
,
ndim
=
3
)
==
1
def
test_3d_non_transposed_w13
(
self
):
# 3D w1/w3: shard_dim=1, hidden=2
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
3
)
==
2
def
test_3d_transposed_w2
(
self
):
# transposed 3D w2: shard_dim=1, hidden=2
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
3
)
==
2
def
test_3d_transposed_w13
(
self
):
# transposed 3D w1/w3: shard_dim=2, hidden=1
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
2
,
ndim
=
3
)
==
1
def
test_1d_returns_zero
(
self
):
# 1D per-channel scales: always returns 0
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
1
)
==
0
assert
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
1
)
==
0
def
test_invalid_shard_dim_raises
(
self
):
# shard_dim outside the data dimensions should raise
with
pytest
.
raises
(
ValueError
,
match
=
"not a valid data dimension"
):
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
3
)
class
TestNarrowExpertDataForPadding
:
"""Unit tests for _narrow_expert_data_for_padding."""
def
test_no_narrowing_when_shapes_match
(
self
):
expert_data
=
torch
.
zeros
(
1024
,
1024
)
loaded_weight
=
torch
.
randn
(
1024
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
assert
result
.
shape
==
loaded_weight
.
shape
assert
result
.
data_ptr
()
==
expert_data
.
data_ptr
()
def
test_narrow_w2_hidden_dim
(
self
):
# w2: (hidden_size, intermediate_size) - hidden_size padded at dim 0
expert_data
=
torch
.
zeros
(
3072
,
1024
)
loaded_weight
=
torch
.
randn
(
2688
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
assert
result
.
shape
==
(
2688
,
1024
)
def
test_narrow_w13_hidden_dim
(
self
):
# w1/w3: (intermediate_size, hidden_size) - hidden_size padded at dim 1
expert_data
=
torch
.
zeros
(
2048
,
3072
)
loaded_weight
=
torch
.
randn
(
2048
,
2688
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
1
)
assert
result
.
shape
==
(
2048
,
2688
)
def
test_narrow_transposed_w2
(
self
):
# transposed w2: (intermediate_size, hidden_size) - hidden at dim 1
expert_data
=
torch
.
zeros
(
1024
,
3072
)
loaded_weight
=
torch
.
randn
(
1024
,
2688
)
hidden_dim
=
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
2
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
assert
result
.
shape
==
(
1024
,
2688
)
def
test_narrow_3d_full_load
(
self
):
# 3D tensor for full_load path: w2 (num_experts, hidden_size, intermediate)
expert_data
=
torch
.
zeros
(
8
,
3072
,
1024
)
loaded_weight
=
torch
.
randn
(
8
,
2688
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
1
)
assert
result
.
shape
==
(
8
,
2688
,
1024
)
def
test_narrow_1d_scale
(
self
):
# 1D scale tensor: per-channel w2 scale (hidden_size,)
expert_data
=
torch
.
zeros
(
3072
)
loaded_weight
=
torch
.
randn
(
2688
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
assert
result
.
shape
==
(
2688
,)
def
test_scalar_weight_no_op
(
self
):
# 0-dim tensor should be a no-op
expert_data
=
torch
.
zeros
(
3072
)
loaded_weight
=
torch
.
tensor
(
1.0
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
# ndim == 0, so no narrowing
assert
result
.
shape
==
(
3072
,)
def
test_no_narrowing_when_loaded_weight_larger
(
self
):
# Guard: don't narrow if loaded_weight is larger than expert_data
expert_data
=
torch
.
zeros
(
2688
,
1024
)
loaded_weight
=
torch
.
randn
(
3072
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
assert
result
.
shape
==
(
2688
,
1024
)
assert
result
.
data_ptr
()
==
expert_data
.
data_ptr
()
def
test_negative_hidden_dim_is_noop
(
self
):
# Negative hidden_dim should be a safe no-op (0 <= check)
expert_data
=
torch
.
zeros
(
3072
,
1024
)
loaded_weight
=
torch
.
randn
(
2688
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=-
1
)
# -1 fails the 0 <= check, so no narrowing
assert
result
.
shape
==
(
3072
,
1024
)
assert
result
.
data_ptr
()
==
expert_data
.
data_ptr
()
def
test_only_narrows_hidden_dim
(
self
):
# Verify that only the specified hidden_dim is narrowed,
# even when other dimensions also differ
expert_data
=
torch
.
zeros
(
3072
,
2048
)
loaded_weight
=
torch
.
randn
(
2688
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
# Only dim 0 (hidden) should be narrowed; dim 1 stays at 2048
assert
result
.
shape
==
(
2688
,
2048
)
def
test_narrowed_data_shares_storage
(
self
):
# Verify narrowing returns a view (writes go to original tensor)
expert_data
=
torch
.
zeros
(
3072
,
1024
)
loaded_weight
=
torch
.
randn
(
2688
,
1024
)
result
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
0
)
result
.
copy_
(
loaded_weight
)
# The first 2688 rows of expert_data should now have loaded_weight
assert
torch
.
equal
(
expert_data
[:
2688
,
:],
loaded_weight
)
# Padded region should remain zero
assert
torch
.
equal
(
expert_data
[
2688
:,
:],
torch
.
zeros
(
3072
-
2688
,
1024
))
class
TestWeightLoadingWithPaddedHiddenSize
:
"""Integration-style tests that simulate padded weight loading."""
def
test_load_w2_with_padding
(
self
):
"""Simulate loading w2 weights when hidden_size is padded."""
padded_hidden
=
3072
original_hidden
=
2688
intermediate
=
1024
expert_data_full
=
torch
.
zeros
(
padded_hidden
,
intermediate
)
loaded_weight
=
torch
.
randn
(
original_hidden
,
intermediate
)
# w2 non-transposed: shard_dim=1, hidden_dim=0
hidden_dim
=
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
2
)
expert_data
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data_full
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
expert_data
.
copy_
(
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
[:
original_hidden
,
:],
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
[
original_hidden
:,
:],
torch
.
zeros
(
padded_hidden
-
original_hidden
,
intermediate
),
)
def
test_load_w13_with_padding
(
self
):
"""Simulate loading w1/w3 weights when hidden_size is padded."""
padded_hidden
=
3072
original_hidden
=
2688
intermediate
=
1024
# w1/w3: (intermediate_size, hidden_size)
expert_data_full
=
torch
.
zeros
(
intermediate
,
padded_hidden
)
loaded_weight
=
torch
.
randn
(
intermediate
,
original_hidden
)
# w1 non-transposed: shard_dim=0, hidden_dim=1
hidden_dim
=
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
2
)
expert_data
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data_full
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
expert_data
.
copy_
(
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
[:,
:
original_hidden
],
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
[:,
original_hidden
:],
torch
.
zeros
(
intermediate
,
padded_hidden
-
original_hidden
),
)
def
test_load_transposed_w2_with_padding
(
self
):
"""Simulate loading transposed w2 (GPTQ) with padded hidden_size."""
padded_hidden
=
3072
original_hidden
=
2688
intermediate
=
1024
# transposed w2: (intermediate_size, hidden_size), shard_dim=0
expert_data_full
=
torch
.
zeros
(
intermediate
,
padded_hidden
)
loaded_weight
=
torch
.
randn
(
intermediate
,
original_hidden
)
hidden_dim
=
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
0
,
ndim
=
2
)
expert_data
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data_full
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
expert_data
.
copy_
(
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
[:,
:
original_hidden
],
loaded_weight
)
def
test_no_padding_is_noop
(
self
):
"""Verify that when sizes match, behavior is unchanged."""
hidden
=
2048
intermediate
=
1024
expert_data_full
=
torch
.
zeros
(
hidden
,
intermediate
)
loaded_weight
=
torch
.
randn
(
hidden
,
intermediate
)
hidden_dim
=
FusedMoE
.
_get_hidden_dim
(
shard_dim
=
1
,
ndim
=
2
)
expert_data
=
FusedMoE
.
_narrow_expert_data_for_padding
(
expert_data_full
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
expert_data
.
copy_
(
loaded_weight
)
assert
torch
.
equal
(
expert_data_full
,
loaded_weight
)
def
test_bnb_shape_mismatch_raises
(
self
):
"""BnB + padded hidden_size should raise via weight_loader."""
from
unittest.mock
import
MagicMock
num_experts
=
1
padded_packed
=
3072
# padded packed size
original_packed
=
2688
# original packed size
# Build a param that looks like a BnB 4-bit MoE weight.
param_data
=
torch
.
zeros
(
num_experts
,
padded_packed
,
1
,
dtype
=
torch
.
uint8
)
param
=
torch
.
nn
.
Parameter
(
param_data
,
requires_grad
=
False
)
param
.
use_bitsandbytes_4bit
=
True
loaded_weight
=
torch
.
randint
(
0
,
255
,
(
original_packed
,
1
),
dtype
=
torch
.
uint8
)
# Minimal FusedMoE mock so weight_loader reaches the BnB path.
moe
=
MagicMock
(
spec
=
FusedMoE
)
moe
.
quant_config
=
None
moe
.
quant_method
=
MagicMock
()
moe
.
quant_method
.
__class__
.
__name__
=
"BitsAndBytesMethod"
moe
.
_expert_map
=
None
moe
.
tp_rank
=
0
# Call the real weight_loader (unbound) with our mock as self.
with
pytest
.
raises
(
ValueError
,
match
=
"BitsAndBytes"
):
FusedMoE
.
weight_loader
(
moe
,
param
,
loaded_weight
,
weight_name
=
"w2"
,
shard_id
=
"w2"
,
expert_id
=
0
,
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
3896e021
...
@@ -860,6 +860,10 @@ class FusedMoE(CustomOp):
...
@@ -860,6 +860,10 @@ class FusedMoE(CustomOp):
):
):
# for per channel weight quantization
# for per channel weight quantization
if
shard_id
==
"w2"
:
if
shard_id
==
"w2"
:
hidden_dim
=
self
.
_get_hidden_dim
(
shard_dim
,
expert_data
.
ndim
)
expert_data
=
self
.
_narrow_expert_data_for_padding
(
expert_data
,
loaded_weight
,
hidden_dim
=
hidden_dim
)
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
elif
shard_id
in
(
"w1"
,
"w3"
):
elif
shard_id
in
(
"w1"
,
"w3"
):
self
.
_load_w13
(
self
.
_load_w13
(
...
@@ -870,6 +874,59 @@ class FusedMoE(CustomOp):
...
@@ -870,6 +874,59 @@ class FusedMoE(CustomOp):
tp_rank
=
tp_rank
,
tp_rank
=
tp_rank
,
)
)
@
staticmethod
def
_get_hidden_dim
(
shard_dim
:
int
,
ndim
:
int
)
->
int
:
"""Compute the hidden dimension index from the shard (intermediate)
dimension and tensor rank.
For 2D weight tensors the two data dims are (0, 1). For 3D tensors
with an expert dimension at dim 0, they are (1, 2). ``shard_dim``
occupies one of these; the hidden dimension is the other.
For 1D tensors (e.g. per-channel scales) returns 0.
"""
if
ndim
<
2
:
return
0
dim_a
=
ndim
-
2
dim_b
=
ndim
-
1
if
shard_dim
==
dim_a
:
return
dim_b
if
shard_dim
==
dim_b
:
return
dim_a
raise
ValueError
(
f
"shard_dim=
{
shard_dim
}
is not a valid data dimension "
f
"for a
{
ndim
}
D tensor (expected
{
dim_a
}
or
{
dim_b
}
)"
)
@
staticmethod
def
_narrow_expert_data_for_padding
(
expert_data
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
hidden_dim
:
int
,
)
->
torch
.
Tensor
:
"""Narrow expert_data hidden dim to match loaded_weight for padded
hidden_size.
When backends (e.g., DeepEP) round up hidden_size, weight parameters
are larger than checkpoint weights. Narrow the padded hidden dimension
before copying.
Args:
expert_data: The (possibly padded) parameter tensor to narrow.
loaded_weight: The checkpoint weight tensor with original size.
hidden_dim: The dimension index corresponding to hidden_size.
Must be non-negative.
"""
if
(
loaded_weight
.
ndim
>
0
and
0
<=
hidden_dim
<
expert_data
.
ndim
and
hidden_dim
<
loaded_weight
.
ndim
and
expert_data
.
shape
[
hidden_dim
]
>
loaded_weight
.
shape
[
hidden_dim
]
):
expert_data
=
expert_data
.
narrow
(
hidden_dim
,
0
,
loaded_weight
.
shape
[
hidden_dim
]
)
return
expert_data
def
_load_w13
(
def
_load_w13
(
self
,
self
,
expert_data
:
torch
.
Tensor
,
expert_data
:
torch
.
Tensor
,
...
@@ -907,13 +964,10 @@ class FusedMoE(CustomOp):
...
@@ -907,13 +964,10 @@ class FusedMoE(CustomOp):
else
:
else
:
assert
shard_id
==
"w3"
assert
shard_id
==
"w3"
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
hidden_dim
=
self
.
_get_hidden_dim
(
shard_dim
,
expert_data
.
ndim
)
# Handle padding: if loaded_weight is smaller than expert_data (can happen
expert_data
=
self
.
_narrow_expert_data_for_padding
(
# on last TP shard with padding), copy to top-left corner
expert_data
,
loaded_weight
,
hidden_dim
=
hidden_dim
if
expert_data
.
shape
!=
loaded_weight
.
shape
:
)
expert_data
=
expert_data
[
:
loaded_weight
.
shape
[
0
],
:
loaded_weight
.
shape
[
1
]
]
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
def
_load_w2
(
def
_load_w2
(
...
@@ -943,12 +997,10 @@ class FusedMoE(CustomOp):
...
@@ -943,12 +997,10 @@ class FusedMoE(CustomOp):
narrow_size
=
min
(
shard_size
,
available
)
narrow_size
=
min
(
shard_size
,
available
)
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_offset
,
narrow_size
)
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_offset
,
narrow_size
)
# w2, down_proj: Load into only logical weight of w2.
# w2, down_proj: Load into only logical weight of w2.
# Handle padding: if loaded_weight is smaller than expert_data (can happen
hidden_dim
=
self
.
_get_hidden_dim
(
shard_dim
,
expert_data
.
ndim
)
# on last TP shard with padding), copy to top-left corner
expert_data
=
self
.
_narrow_expert_data_for_padding
(
if
expert_data
.
shape
!=
loaded_weight
.
shape
:
expert_data
,
loaded_weight
,
hidden_dim
=
hidden_dim
expert_data
=
expert_data
[
)
:
loaded_weight
.
shape
[
0
],
:
loaded_weight
.
shape
[
1
]
]
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
def
_load_single_value
(
def
_load_single_value
(
...
@@ -1095,9 +1147,25 @@ class FusedMoE(CustomOp):
...
@@ -1095,9 +1147,25 @@ class FusedMoE(CustomOp):
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
if
shard_id
==
"w2"
:
if
shard_id
==
"w2"
:
# BnB params are stored as flat packed tensors (e.g.
# (packed_size, 1)), not in the logical weight layout.
# Narrowing packed data for hidden-dim padding is not
# meaningful, so require an exact shape match.
if
expert_data
.
shape
!=
loaded_weight
.
shape
:
raise
ValueError
(
"BitsAndBytes quantization with padded hidden_size "
"(e.g., from DeepEP) is not supported. "
f
"Parameter shape
{
tuple
(
expert_data
.
shape
)
}
!= "
f
"checkpoint shape
{
tuple
(
loaded_weight
.
shape
)
}
"
)
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
elif
shard_id
in
(
"w1"
,
"w3"
):
elif
shard_id
in
(
"w1"
,
"w3"
):
# BNB inflight quantization has already sharded the weights
# BnB stores weights as flat packed tensors. _load_w13 is
# still used to split the w1/w3 portions along shard_dim.
# _narrow_expert_data_for_padding will be a no-op since
# packed sizes should already match; if DeepEP padding
# causes a mismatch the copy_() will fail with a clear
# shape error.
full_load
=
True
full_load
=
True
self
.
_load_w13
(
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment