Unverified commit 4eee77b8, authored by fxmarty-amd and committed by GitHub
Browse files

[fix][MOE] Fix MOE experts `intermediate_size` dimension not being narrowed...


[fix][MOE] Fix MOE experts `intermediate_size` dimension not being narrowed before weight loading (#39688)
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
parent a1993b96
......@@ -257,6 +257,41 @@ class TestWeightLoadingWithPaddedHiddenSize:
assert torch.equal(expert_data_full, loaded_weight)
def test_narrow_shard_dim(self):
    """Simulate a w2 load where hidden and intermediate sizes are both padded.

    A zero-filled, padded buffer is narrowed on the hidden and shard
    dimensions, the smaller checkpoint tensor is copied through the
    narrowed result, and the full buffer is then inspected: the top-left
    region must hold the checkpoint values and all padding must stay zero.
    """
    hidden_padded, hidden_orig = 3072, 2688
    inter_padded, inter_orig = 1024, 896
    shard_dim = 1

    expert_data_full = torch.zeros(hidden_padded, inter_padded)
    loaded_weight = torch.randn(hidden_orig, inter_orig)

    hidden_dim = FusedMoE._get_hidden_dim(shard_dim=shard_dim, ndim=2)
    narrowed = FusedMoE._narrow_expert_data_for_padding(
        expert_data_full,
        loaded_weight,
        hidden_dim=hidden_dim,
        shard_dim=shard_dim,
    )
    narrowed.copy_(loaded_weight)

    # The checkpoint values must land in the unpadded top-left region.
    loaded_region = expert_data_full[:hidden_orig, :inter_orig]
    assert torch.equal(loaded_region, loaded_weight)

    # Rows past the original hidden size must be untouched.
    hidden_padding = expert_data_full[hidden_orig:, :]
    expected_hidden_padding = torch.zeros(
        hidden_padded - hidden_orig, inter_padded
    )
    assert torch.equal(hidden_padding, expected_hidden_padding)

    # Columns past the original intermediate size must be untouched.
    inter_padding = expert_data_full[:hidden_orig, inter_orig:]
    expected_inter_padding = torch.zeros(
        hidden_orig, inter_padded - inter_orig
    )
    assert torch.equal(inter_padding, expected_inter_padding)
def test_bnb_shape_mismatch_raises(self):
"""BnB + padded hidden_size should raise via weight_loader."""
from unittest.mock import MagicMock
......
......@@ -842,7 +842,10 @@ class FusedMoE(CustomOp):
if shard_id == "w2":
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim
expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
)
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
......@@ -882,29 +885,33 @@ class FusedMoE(CustomOp):
expert_data: torch.Tensor,
loaded_weight: torch.Tensor,
hidden_dim: int,
shard_dim: int | None = None,
) -> torch.Tensor:
"""Narrow expert_data hidden dim to match loaded_weight for padded
hidden_size.
"""Narrow expert_data to match loaded_weight for padded dimensions.
When backends (e.g., DeepEP) round up hidden_size, weight parameters
are larger than checkpoint weights. Narrow the padded hidden dimension
before copying.
before copying. Similarly, when padding occurs on the shard
(intermediate) dimension (e.g. for MXFP4 GEMM), narrow that dimension
as well.
Args:
expert_data: The (possibly padded) parameter tensor to narrow.
loaded_weight: The checkpoint weight tensor with original size.
hidden_dim: The dimension index corresponding to hidden_size.
Must be non-negative.
shard_dim: The dimension index corresponding to the shard
(intermediate) dimension. Defaults to `None`.
"""
dims = (hidden_dim,) if shard_dim is None else (hidden_dim, shard_dim)
if loaded_weight.ndim > 0:
for dim in dims:
if (
loaded_weight.ndim > 0
and 0 <= hidden_dim < expert_data.ndim
and hidden_dim < loaded_weight.ndim
and expert_data.shape[hidden_dim] > loaded_weight.shape[hidden_dim]
0 <= dim < expert_data.ndim
and dim < loaded_weight.ndim
and expert_data.shape[dim] > loaded_weight.shape[dim]
):
expert_data = expert_data.narrow(
hidden_dim, 0, loaded_weight.shape[hidden_dim]
)
expert_data = expert_data.narrow(dim, 0, loaded_weight.shape[dim])
return expert_data
def _load_w13(
......@@ -946,7 +953,10 @@ class FusedMoE(CustomOp):
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim
expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
)
expert_data.copy_(loaded_weight)
......@@ -979,7 +989,10 @@ class FusedMoE(CustomOp):
# w2, down_proj: Load into only logical weight of w2.
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim
expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
)
expert_data.copy_(loaded_weight)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment