Unverified Commit 9d1c4747 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[LoRA][1/N]Remove LoRA extra vocab (#28382)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 8c32c6e4
......@@ -42,7 +42,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
......@@ -319,22 +318,17 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
......
......@@ -31,7 +31,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
......@@ -400,28 +399,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
self.config = config
self.vllm_config = vllm_config
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
if not lora_config
else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.logits_processor = LogitsProcessor(config.vocab_size)
else:
self.lm_head = PPMissingLayer()
......
......@@ -42,7 +42,6 @@ class CausalMixin(VllmModelForTextGeneration):
self.skip_prefixes.append("lm_head.")
if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.text_config.vocab_size,
self.text_config.hidden_size,
......@@ -56,7 +55,7 @@ class CausalMixin(VllmModelForTextGeneration):
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale
self.text_config.vocab_size, scale=logit_scale
)
else:
self.lm_head = PPMissingLayer()
......
......@@ -890,7 +890,7 @@ class WhisperForConditionalGeneration(
self.dtype = vllm_config.model_config.dtype
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
self.proj_out = ParallelLMHead(
config.vocab_size,
config.d_model,
......@@ -899,9 +899,7 @@ class WhisperForConditionalGeneration(
)
self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
def forward(
self,
......
......@@ -38,7 +38,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
......@@ -692,19 +691,13 @@ class Zamba2Model(nn.Module):
assert not is_lora_enabled
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.vocab_size = config.vocab_size
# Initialize token embeddings
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
# Map hybrid layer indices to block indices
......@@ -911,7 +904,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
(not supported by Mamba)
"""
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
super().__init__()
......@@ -919,9 +912,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
self.vllm_config = vllm_config
self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
# Initialize core model
self.model = Zamba2Model(
......@@ -930,23 +920,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
# Initialize language modeling head
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Tie weights with input embeddings if using same dimensions
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment