Unverified commit f4fc7337, authored by Zijian Hu and committed by GitHub
Browse files

[Bugfix] support `tie_word_embeddings` for all models (#5724)

parent 0df7ec0b
...@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module): ...@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config) self.model = OPTModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module): ...@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super().__init__() super().__init__()
self.config = config self.config = config
# lm_head uses a bias, so it cannot share weights with the word embeddings
assert not config.tie_word_embeddings
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
) )
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
...@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal): ...@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module): ...@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module): ...@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment