Unverified Commit 14d90617 authored by Chayenne's avatar Chayenne Committed by GitHub
Browse files

Bug: fix lm head weights in Qwen models (#3777)

parent d37f9551
......@@ -379,8 +379,6 @@ class Qwen2ForCausalLM(nn.Module):
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
if name.startswith("lm_head"):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
......
......@@ -62,7 +62,11 @@ class Qwen2ForRewardModel(nn.Module):
return EmbeddingPoolerOutput(pooled_logits)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
return Qwen2ForCausalLM.load_weights(self, weights)
# Filter out lm_head weights of Qwen2ForCausalLM
filtered_weights = [
(name, w) for name, w in weights if not name.startswith("lm_head")
]
return Qwen2ForCausalLM.load_weights(self, filtered_weights)
EntryClass = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment