"vscode:/vscode.git/clone" did not exist on "e85c43ab87553426f830f5f49ddc34d7231c5212"
Unverified Commit 45b3a6a2 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

Revert "[ModelOpt] Fix Weight Loading for DSR1-FP4 Quantization (#9712)" (#10176)

parent 9a18aa54
...@@ -235,9 +235,8 @@ class ReplicatedLinear(LinearBase): ...@@ -235,9 +235,8 @@ class ReplicatedLinear(LinearBase):
loaded_weight = loaded_weight[:1] loaded_weight = loaded_weight[:1]
else: else:
raise ValueError(f"{loaded_weight} are not all equal") raise ValueError(f"{loaded_weight} are not all equal")
assert (
param.size() == loaded_weight.size() assert param.size() == loaded_weight.size()
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
......
...@@ -646,13 +646,6 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -646,13 +646,6 @@ class ModelOptFp4Config(QuantizationConfig):
regex_str = pattern.replace(".", r"\.").replace("*", r".*") regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix): if re.fullmatch(regex_str, prefix):
return True return True
# Check if the last part of the excluded pattern is contained in the last part of the prefix
# This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
pattern_last_part = pattern.split(".")[-1]
prefix_last_part = prefix.split(".")[-1]
if pattern_last_part in prefix_last_part:
return True
return False return False
def get_quant_method( def get_quant_method(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment