"vscode:/vscode.git/clone" did not exist on "9a5664d4a4d212a6ebad79b15b11eb8d3ab2a0b2"
Unverified commit a7c87168, authored by Yang Fan, committed by GitHub
Browse files

Fix tie_word_embeddings for Qwen2. (#3344)

parent 429284dc
......@@ -299,7 +299,11 @@ class Qwen2ForCausalLM(nn.Module):
self.config = config
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if not config.tie_word_embeddings:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
......@@ -318,7 +322,11 @@ class Qwen2ForCausalLM(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens
......@@ -340,6 +348,8 @@ class Qwen2ForCausalLM(nn.Module):
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment