Unverified commit 4eee77b8, authored by fxmarty-amd, committed by GitHub
Browse files

[fix][MOE] Fix MOE experts `intermediate_size` dimension not being narrowed...


[fix][MOE] Fix MOE experts `intermediate_size` dimension not being narrowed before weight loading (#39688)
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
parent a1993b96
...@@ -257,6 +257,41 @@ class TestWeightLoadingWithPaddedHiddenSize: ...@@ -257,6 +257,41 @@ class TestWeightLoadingWithPaddedHiddenSize:
assert torch.equal(expert_data_full, loaded_weight) assert torch.equal(expert_data_full, loaded_weight)
def test_narrow_shard_dim(self):
    """Simulate loading w2 when hidden_size and intermediate_size are
    both padded.

    A zero-initialised parameter of the padded shape is passed through
    ``FusedMoE._narrow_expert_data_for_padding`` together with a smaller
    checkpoint tensor; the checkpoint is then copied into the returned
    tensor. The test checks that the checkpoint values land in the
    top-left block of the padded parameter and that every padded row and
    column is still zero.
    """
    hidden_padded, hidden_orig = 3072, 2688
    inter_padded, inter_orig = 1024, 896

    param_full = torch.zeros(hidden_padded, inter_padded)
    checkpoint = torch.randn(hidden_orig, inter_orig)

    # For this 2-D weight the intermediate size is along dim 1.
    shard_dim = 1
    narrowed = FusedMoE._narrow_expert_data_for_padding(
        param_full,
        checkpoint,
        hidden_dim=FusedMoE._get_hidden_dim(shard_dim=shard_dim, ndim=2),
        shard_dim=shard_dim,
    )
    narrowed.copy_(checkpoint)

    # The un-padded region must hold exactly the checkpoint values.
    loaded_region = param_full[:hidden_orig, :inter_orig]
    assert torch.equal(loaded_region, checkpoint)

    # Rows added by hidden_size padding must stay zero.
    hidden_padding = param_full[hidden_orig:, :]
    expected_hidden_padding = torch.zeros(
        hidden_padded - hidden_orig, inter_padded
    )
    assert torch.equal(hidden_padding, expected_hidden_padding)

    # Columns added by intermediate_size padding must stay zero.
    inter_padding = param_full[:hidden_orig, inter_orig:]
    expected_inter_padding = torch.zeros(
        hidden_orig, inter_padded - inter_orig
    )
    assert torch.equal(inter_padding, expected_inter_padding)
def test_bnb_shape_mismatch_raises(self): def test_bnb_shape_mismatch_raises(self):
"""BnB + padded hidden_size should raise via weight_loader.""" """BnB + padded hidden_size should raise via weight_loader."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
......
...@@ -842,7 +842,10 @@ class FusedMoE(CustomOp): ...@@ -842,7 +842,10 @@ class FusedMoE(CustomOp):
if shard_id == "w2": if shard_id == "w2":
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim) hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding( expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
) )
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"): elif shard_id in ("w1", "w3"):
...@@ -882,29 +885,33 @@ class FusedMoE(CustomOp): ...@@ -882,29 +885,33 @@ class FusedMoE(CustomOp):
expert_data: torch.Tensor, expert_data: torch.Tensor,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
hidden_dim: int, hidden_dim: int,
shard_dim: int | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Narrow expert_data hidden dim to match loaded_weight for padded """Narrow expert_data to match loaded_weight for padded dimensions.
hidden_size.
When backends (e.g., DeepEP) round up hidden_size, weight parameters When backends (e.g., DeepEP) round up hidden_size, weight parameters
are larger than checkpoint weights. Narrow the padded hidden dimension are larger than checkpoint weights. Narrow the padded hidden dimension
before copying. before copying. Similarly, when padding occurs on the shard
(intermediate) dimension (e.g. for MXFP4 GEMM), narrow that dimension
as well.
Args: Args:
expert_data: The (possibly padded) parameter tensor to narrow. expert_data: The (possibly padded) parameter tensor to narrow.
loaded_weight: The checkpoint weight tensor with original size. loaded_weight: The checkpoint weight tensor with original size.
hidden_dim: The dimension index corresponding to hidden_size. hidden_dim: The dimension index corresponding to hidden_size.
Must be non-negative. Must be non-negative.
shard_dim: The dimension index corresponding to the shard
(intermediate) dimension. Defaults to `None`.
""" """
if ( dims = (hidden_dim,) if shard_dim is None else (hidden_dim, shard_dim)
loaded_weight.ndim > 0 if loaded_weight.ndim > 0:
and 0 <= hidden_dim < expert_data.ndim for dim in dims:
and hidden_dim < loaded_weight.ndim if (
and expert_data.shape[hidden_dim] > loaded_weight.shape[hidden_dim] 0 <= dim < expert_data.ndim
): and dim < loaded_weight.ndim
expert_data = expert_data.narrow( and expert_data.shape[dim] > loaded_weight.shape[dim]
hidden_dim, 0, loaded_weight.shape[hidden_dim] ):
) expert_data = expert_data.narrow(dim, 0, loaded_weight.shape[dim])
return expert_data return expert_data
def _load_w13( def _load_w13(
...@@ -946,7 +953,10 @@ class FusedMoE(CustomOp): ...@@ -946,7 +953,10 @@ class FusedMoE(CustomOp):
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim) hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding( expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
) )
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
...@@ -979,7 +989,10 @@ class FusedMoE(CustomOp): ...@@ -979,7 +989,10 @@ class FusedMoE(CustomOp):
# w2, down_proj: Load into only logical weight of w2. # w2, down_proj: Load into only logical weight of w2.
hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim) hidden_dim = self._get_hidden_dim(shard_dim, expert_data.ndim)
expert_data = self._narrow_expert_data_for_padding( expert_data = self._narrow_expert_data_for_padding(
expert_data, loaded_weight, hidden_dim=hidden_dim expert_data,
loaded_weight,
hidden_dim=hidden_dim,
shard_dim=shard_dim,
) )
expert_data.copy_(loaded_weight) expert_data.copy_(loaded_weight)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment