Unverified commit 14d90617, authored by Chayenne and committed by GitHub

Bug: fix lm head weights in Qwen models (#3777)

parent d37f9551
@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
-            if name.startswith("lm_head"):
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
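This first hunk removes the unconditional skip of lm_head-prefixed weights in Qwen2ForCausalLM.load_weights, so lm_head tensors present in a checkpoint are now actually loaded instead of being silently dropped (the bug named in the commit title). The filtering is relocated into Qwen2ForRewardModel in the next hunk; a combined sketch of the resulting pattern follows it.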
@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
         return EmbeddingPoolerOutput(pooled_logits)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        return Qwen2ForCausalLM.load_weights(self, weights)
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)

 EntryClass = [
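Taken together, the two hunks move lm_head filtering out of the shared causal-LM loader and into the reward model, which replaces the LM head with a pooling/score head and therefore has no lm_head parameter to receive those tensors. Below is a minimal, self-contained sketch of the resulting load pattern; the Toy* classes and their parameter names are illustrative stand-ins, not the actual sglang modules.

from typing import Iterable, Tuple

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    # Stand-in for Qwen2ForCausalLM: owns an lm_head and, after this fix,
    # loads it from the checkpoint like any other parameter.
    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params = dict(self.named_parameters())
        for name, tensor in weights:
            # No lm_head skip here anymore; names this module does not
            # own are simply ignored in this toy loader.
            if name in params:
                params[name].data.copy_(tensor)


class ToyRewardModel(ToyCausalLM):
    # Stand-in for Qwen2ForRewardModel: the LM head is replaced by a scalar
    # score head, so lm_head tensors from the checkpoint have no destination.
    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__(hidden, vocab)
        del self.lm_head
        self.score = nn.Linear(hidden, 1, bias=False)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Same idea as the patch: drop lm_head.* before delegating to the
        # shared causal-LM loader.
        filtered = [(n, w) for n, w in weights if not n.startswith("lm_head")]
        return ToyCausalLM.load_weights(self, filtered)


if __name__ == "__main__":
    ckpt = [
        ("embed.weight", torch.randn(16, 8)),
        ("lm_head.weight", torch.randn(16, 8)),
    ]
    ToyCausalLM().load_weights(ckpt)     # lm_head.weight is applied
    ToyRewardModel().load_weights(ckpt)  # lm_head.weight is filtered out

In the real models, forwarding lm_head.* to a module that no longer has that parameter is exactly what the new filter in Qwen2ForRewardModel.load_weights guards against.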