Unverified Commit f4fc7337 authored by Zijian Hu's avatar Zijian Hu Committed by GitHub
Browse files

[Bugfix] support `tie_word_embeddings` for all models (#5724)

parent 0df7ec0b
...@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module): ...@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
......
...@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module): ...@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
lora_config: Optional[LoRAConfig] = None): lora_config: Optional[LoRAConfig] = None):
super().__init__() super().__init__()
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config self.config = config
self.model = BartModel(config, self.model = BartModel(config,
cache_config, cache_config,
......
...@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__() super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module): ...@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config) self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head = self.transformer.word_embeddings if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length",
8192) 8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config) self.transformer = ChatGLMModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module): ...@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module): ...@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
): ):
super().__init__() super().__init__()
self.config = config self.config = config
if config.tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config) self.transformer = DbrxModel(config, cache_config, quant_config)
......
...@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module): ...@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del lora_config # Unused. del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config) self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module): ...@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config, cache_config,
quant_config, quant_config,
prefix="transformer") prefix="transformer")
self.lm_head = self.transformer.wte if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config, self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config) lora_config)
self.lm_head = self.transformer.wte if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
) )
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output = ParallelLMHead(config.vocab_size, self.output = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
...@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module): ...@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)
self.lm_head = self.transformer.wte if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
else: else:
......
...@@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in: additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
...@@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -331,7 +331,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values: The pixels in each input image. pixel_values: The pixels in each input image.
See also: See also:
:class:`LlavaImageInputs` :class:`LlavaImageInputs`
""" """
......
...@@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -545,7 +545,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
9047, 13566, 29901]`. 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in: additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
...@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
batch. batch.
pixel_values: The pixels in each grid patch for each input image. pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image. image_sizes: The original `(height, width)` for each input image.
See also: See also:
:class:`LlavaNextImageInputs` :class:`LlavaNextImageInputs`
""" """
......
...@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
# and config class
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
) )
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment