Unverified Commit 191d836f authored by Peng Zhang, committed by GitHub

fix: minor fix for modelopt weight load compatibility (#7953)

parent 86044712
@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
         self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
         ):
             raise ValueError("expert_data and loaded_weight must be torch.Tensor")
-        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
             raise ValueError(
                 f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
             )
...
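For context, the change stashes the quantization config on the FusedMoE layer and then scopes the strict 2D shape check to modelopt checkpoints only, so other quantization schemes can load expert weights of a different rank. Below is a minimal, self-contained sketch of the relaxed check, assuming a hypothetical QuantConfig stand-in with a get_name() method (the real layer reads this from self.quant_config as set in the first hunk); it is an illustration of the guard, not the layer's actual code path.

    import torch

    class QuantConfig:
        """Hypothetical stand-in for the layer's quantization config."""
        def __init__(self, name: str):
            self._name = name

        def get_name(self) -> str:
            return self._name

    def check_expert_shapes(quant_config, expert_data: torch.Tensor,
                            loaded_weight: torch.Tensor) -> None:
        if not isinstance(expert_data, torch.Tensor) or not isinstance(
                loaded_weight, torch.Tensor):
            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
        # Only modelopt checkpoints are required to ship 2D expert weights;
        # other schemes may legitimately load tensors of a different rank.
        if (quant_config is not None
                and "modelopt" in quant_config.get_name()
                and (expert_data.dim() != 2 or loaded_weight.dim() != 2)):
            raise ValueError(
                f"Expected 2D tensors, got expert_data shape {expert_data.shape} "
                f"and loaded_weight shape {loaded_weight.shape}")

    # A 3D weight now passes without a config or with a non-modelopt config,
    # but still fails validation under modelopt.
    check_expert_shapes(None, torch.zeros(4, 8), torch.zeros(4, 8))
    check_expert_shapes(QuantConfig("fp8"), torch.zeros(2, 4, 8),
                        torch.zeros(2, 4, 8))
    try:
        check_expert_shapes(QuantConfig("modelopt"), torch.zeros(2, 4, 8),
                            torch.zeros(2, 4, 8))
    except ValueError as e:
        print(e)  # Expected 2D tensors, got ...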