Unverified commit 9875be64, authored by Jee Jee Li, committed by GitHub
Browse files

[LoRA][2/2]Remove LoRA extra vocab (#28545)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent df44df01
...@@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -43,13 +43,10 @@ class PunicaWrapperXPU(PunicaWrapperBase):
lora_index_to_id: list[int | None], lora_index_to_id: list[int | None],
max_loras: int, max_loras: int,
vocab_size: int, vocab_size: int,
extra_vocab_size: int,
**kwargs, **kwargs,
): ):
self.is_prefill = mapping.is_prefill self.is_prefill = mapping.is_prefill
self._update_base_metadata( self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
......
...@@ -166,6 +166,16 @@ def parse_fine_tuned_lora_name( ...@@ -166,6 +166,16 @@ def parse_fine_tuned_lora_name(
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
def is_base_embeddding_weights(name: str) -> bool:
    """Return whether ``name`` names a base-layer embedding weight.

    A name matches when it ends with the hardcoded suffix of the input
    embedding (``.embed_tokens.base_layer.weight``) or of the output
    embedding (``.lm_head.base_layer.weight``).

    Args:
        name: Weight name to check.

    Returns:
        ``True`` if ``name`` ends with either suffix, ``False`` otherwise.
    """
    # Hardcoded suffixes for input & output embedding weights.
    embedding_suffixes = (
        ".embed_tokens.base_layer.weight",
        ".lm_head.base_layer.weight",
    )
    # str.endswith accepts a tuple, so both suffixes are tested in one call.
    return name.endswith(embedding_suffixes)
def is_regex_target_modules( def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str] load_modules: str | list[str], expected_lora_modules: list[str]
) -> bool: ) -> bool:
......
...@@ -121,8 +121,7 @@ class WorkerLoRAManager: ...@@ -121,8 +121,7 @@ class WorkerLoRAManager:
lora_model_id=lora_request.lora_int_id, lora_model_id=lora_request.lora_int_id,
device="cpu", device="cpu",
dtype=self.lora_config.lora_dtype, dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size target_embedding_padding=self.vocab_size,
+ self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules, embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules, embedding_padding_modules=self.embedding_padding_modules,
tensorizer_config_dict=lora_request.tensorizer_config_dict, tensorizer_config_dict=lora_request.tensorizer_config_dict,
...@@ -143,12 +142,6 @@ class WorkerLoRAManager: ...@@ -143,12 +142,6 @@ class WorkerLoRAManager:
# For BadRequestError # For BadRequestError
raise e raise e
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{self.lora_config.lora_extra_vocab_size}."
)
return lora return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
......
...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -46,7 +46,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -261,29 +260,16 @@ class GraniteModel(nn.Module): ...@@ -261,29 +260,16 @@ class GraniteModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
): ):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -420,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -420,28 +406,18 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = GraniteModel( self.model = GraniteModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -453,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -453,7 +429,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale /= config.logits_scaling logit_scale /= config.logits_scaling
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, scale=logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -368,24 +367,18 @@ class LlamaModel(nn.Module): ...@@ -368,24 +367,18 @@ class LlamaModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
): ):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -562,9 +555,7 @@ class LlamaForCausalLM( ...@@ -562,9 +555,7 @@ class LlamaForCausalLM(
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.model = self._init_model( self.model = self._init_model(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -573,20 +564,9 @@ class LlamaForCausalLM( ...@@ -573,20 +564,9 @@ class LlamaForCausalLM(
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -595,7 +575,7 @@ class LlamaForCausalLM( ...@@ -595,7 +575,7 @@ class LlamaForCausalLM(
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -51,7 +51,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -301,23 +300,18 @@ class MixtralModel(nn.Module): ...@@ -301,23 +300,18 @@ class MixtralModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
self.enable_eplb = parallel_config.enable_eplb self.enable_eplb = parallel_config.enable_eplb
...@@ -508,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -508,34 +502,24 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel( self.model = MixtralModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
......
...@@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM): ...@@ -74,5 +74,5 @@ class TeleFLMForCausalLM(LlamaForCausalLM):
self.output_mult = self.config.output_mult / self.mup_scale_factor self.output_mult = self.config.output_mult / self.mup_scale_factor
logit_scale = self.output_mult logit_scale = self.output_mult
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, self.config.vocab_size, logit_scale self.config.vocab_size, scale=logit_scale
) )
...@@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -219,9 +219,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size = model_config.get_hidden_size() self.hidden_size = model_config.get_hidden_size()
self.vocab_size = model_config.get_vocab_size() self.vocab_size = model_config.get_vocab_size()
if self.lora_config is not None:
self.vocab_size += self.lora_config.lora_extra_vocab_size
# Multi-modal data support # Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope self.uses_mrope = model_config.uses_mrope
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment