Unverified Commit 191d836f authored by Peng Zhang's avatar Peng Zhang Committed by GitHub
Browse files

fix: minor fix for modelopt weight load compatibility (#7953)

parent 86044712
......@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
assert self.quant_method is not None
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts=self.local_num_experts,
......@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
):
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
if expert_data.dim() != 2 or loaded_weight.dim() != 2:
if (
self.quant_config is not None
and "modelopt" in self.quant_config.get_name()
and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
):
raise ValueError(
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment