Unverified Commit 9d1c4747, authored by Jee Jee Li, committed by GitHub
Browse files

[LoRA][1/N]Remove LoRA extra vocab (#28382)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 8c32c6e4
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -319,22 +318,17 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -319,22 +318,17 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.lm_head", prefix=f"{prefix}.lm_head",
) )
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
......
...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -400,28 +399,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -400,28 +399,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix) self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
if not lora_config
else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -42,7 +42,6 @@ class CausalMixin(VllmModelForTextGeneration): ...@@ -42,7 +42,6 @@ class CausalMixin(VllmModelForTextGeneration):
self.skip_prefixes.append("lm_head.") self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.text_config.vocab_size, self.text_config.vocab_size,
self.text_config.hidden_size, self.text_config.hidden_size,
...@@ -56,7 +55,7 @@ class CausalMixin(VllmModelForTextGeneration): ...@@ -56,7 +55,7 @@ class CausalMixin(VllmModelForTextGeneration):
logit_scale = getattr(self.text_config, "logit_scale", 1.0) logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale self.text_config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -890,7 +890,7 @@ class WhisperForConditionalGeneration( ...@@ -890,7 +890,7 @@ class WhisperForConditionalGeneration(
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
self.proj_out = ParallelLMHead( self.proj_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
...@@ -899,9 +899,7 @@ class WhisperForConditionalGeneration( ...@@ -899,9 +899,7 @@ class WhisperForConditionalGeneration(
) )
self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens) self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
def forward( def forward(
self, self,
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -692,19 +691,13 @@ class Zamba2Model(nn.Module): ...@@ -692,19 +691,13 @@ class Zamba2Model(nn.Module):
assert not is_lora_enabled assert not is_lora_enabled
self.config = config self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
# Initialize token embeddings # Initialize token embeddings
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
# Map hybrid layer indices to block indices # Map hybrid layer indices to block indices
...@@ -911,7 +904,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -911,7 +904,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
(not supported by Mamba) (not supported by Mamba)
""" """
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
super().__init__() super().__init__()
...@@ -919,9 +912,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -919,9 +912,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
# Initialize core model # Initialize core model
self.model = Zamba2Model( self.model = Zamba2Model(
...@@ -930,23 +920,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -930,23 +920,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
# Initialize language modeling head # Initialize language modeling head
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
# Tie weights with input embeddings if using same dimensions # Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Initialize logits processing and sampling # Initialize logits processing and sampling
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings. """Convert input token IDs to embeddings.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment