Unverified Commit 9d1c4747 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[LoRA][1/N]Remove LoRA extra vocab (#28382)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 8c32c6e4
...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -49,7 +49,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -346,24 +345,18 @@ class ApertusModel(nn.Module): ...@@ -346,24 +345,18 @@ class ApertusModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
): ):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -518,9 +511,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -518,9 +511,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.model = self._init_model( self.model = self._init_model(
vllm_config=vllm_config, vllm_config=vllm_config,
...@@ -529,20 +520,9 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -529,20 +520,9 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -551,7 +531,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -551,7 +531,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -23,7 +23,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -200,7 +199,6 @@ class ArceeModel(nn.Module): ...@@ -200,7 +199,6 @@ class ArceeModel(nn.Module):
self.quant_config = quant_config self.quant_config = quant_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
# Word embeddings (parallelized if using pipeline parallel) # Word embeddings (parallelized if using pipeline parallel)
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
...@@ -209,7 +207,6 @@ class ArceeModel(nn.Module): ...@@ -209,7 +207,6 @@ class ArceeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -383,13 +380,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -383,13 +380,10 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
# Determine vocabulary size (including any LoRA extra tokens # Determine vocabulary size (including any LoRA extra tokens
# for padded LM head) # for padded LM head)
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=vllm_config.quant_config, quant_config=vllm_config.quant_config,
bias=getattr(config, "lm_head_bias", False), bias=getattr(config, "lm_head_bias", False),
prefix=f"{prefix}.lm_head", prefix=f"{prefix}.lm_head",
...@@ -399,7 +393,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -399,7 +393,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
# Placeholder for lm_head on non-last ranks # Placeholder for lm_head on non-last ranks
......
...@@ -490,10 +490,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -490,10 +490,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
......
...@@ -547,18 +547,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -547,18 +547,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.pad_token_id = ( self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.config.pad_token_id if self.config.pad_token_id is not None else -1
) )
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.vocab_size,
config.text_config.hidden_size, config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale)
self.unpadded_vocab_size, self.vocab_size, logit_scale
)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
......
...@@ -402,9 +402,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant ...@@ -402,9 +402,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config self.quant_config = quant_config
self.model = BaiChuanModel( self.model = BaiChuanModel(
......
...@@ -581,10 +581,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -581,10 +581,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
config = vllm_config.model_config.hf_config.get_text_config() config = vllm_config.model_config.hf_config.get_text_config()
vllm_config.model_config.hf_config = config vllm_config.model_config.hf_config = config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.model = BailingMoeModel( self.model = BailingMoeModel(
......
...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -284,21 +283,14 @@ class BambaModel(nn.Module): ...@@ -284,21 +283,14 @@ class BambaModel(nn.Module):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
def get_layer(prefix: str): def get_layer(prefix: str):
...@@ -478,7 +470,7 @@ class BambaForCausalLM( ...@@ -478,7 +470,7 @@ class BambaForCausalLM(
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.quant_config = vllm_config.quant_config self.quant_config = vllm_config.quant_config
...@@ -488,24 +480,14 @@ class BambaForCausalLM( ...@@ -488,24 +480,14 @@ class BambaForCausalLM(
self.model = BambaModel( self.model = BambaModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
......
...@@ -963,9 +963,9 @@ class ChameleonForConditionalGeneration( ...@@ -963,9 +963,9 @@ class ChameleonForConditionalGeneration(
self.model = ChameleonModel( self.model = ChameleonModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -973,9 +973,7 @@ class ChameleonForConditionalGeneration( ...@@ -973,9 +973,7 @@ class ChameleonForConditionalGeneration(
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
......
...@@ -433,10 +433,9 @@ class ChatGLMBaseModel(nn.Module): ...@@ -433,10 +433,9 @@ class ChatGLMBaseModel(nn.Module):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.lora_config = lora_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -288,17 +288,12 @@ class CohereModel(nn.Module): ...@@ -288,17 +288,12 @@ class CohereModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.config = config self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size config.vocab_size, config.hidden_size
) )
...@@ -424,17 +419,15 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -424,17 +419,15 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing command R models have `tie_word_embeddings` # currently all existing command R models have `tie_word_embeddings`
# enabled # enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.quant_config = quant_config self.quant_config = quant_config
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale config.vocab_size, scale=config.logit_scale
) )
self.model = CohereModel( self.model = CohereModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
......
...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -25,7 +25,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -441,21 +440,17 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -441,21 +440,17 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
if config.tie_word_embeddings: if config.tie_word_embeddings:
raise ValueError("tie_word_embeddings is not supported for Dbrx models.") raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel( self.transformer = DbrxModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
) )
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
......
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -323,16 +322,11 @@ class ExaoneModel(nn.Module): ...@@ -323,16 +322,11 @@ class ExaoneModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.wte = config.vocab_size self.wte = config.vocab_size
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
...@@ -340,7 +334,6 @@ class ExaoneModel(nn.Module): ...@@ -340,7 +334,6 @@ class ExaoneModel(nn.Module):
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -489,10 +482,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -489,10 +482,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = ExaoneModel( self.transformer = ExaoneModel(
...@@ -500,18 +492,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -500,18 +492,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
prefix=maybe_prefix(prefix, "model"), prefix=maybe_prefix(prefix, "model"),
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -520,7 +503,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -520,7 +503,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -311,23 +310,17 @@ class Exaone4Model(nn.Module): ...@@ -311,23 +310,17 @@ class Exaone4Model(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
): ):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
) )
else: else:
...@@ -476,10 +469,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -476,10 +469,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Exaone4Model( self.model = Exaone4Model(
...@@ -487,18 +478,9 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -487,18 +478,9 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
prefix=maybe_prefix(prefix, "model"), prefix=maybe_prefix(prefix, "model"),
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -507,7 +489,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -507,7 +489,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -30,7 +30,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -424,21 +423,15 @@ class FalconH1Model(nn.Module): ...@@ -424,21 +423,15 @@ class FalconH1Model(nn.Module):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
self.embedding_multiplier = config.embedding_multiplier self.embedding_multiplier = config.embedding_multiplier
else: else:
...@@ -572,7 +565,7 @@ class FalconH1ForCausalLM( ...@@ -572,7 +565,7 @@ class FalconH1ForCausalLM(
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
self.quant_config = vllm_config.quant_config self.quant_config = vllm_config.quant_config
...@@ -584,21 +577,11 @@ class FalconH1ForCausalLM( ...@@ -584,21 +577,11 @@ class FalconH1ForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.tie_word_embeddings = config.tie_word_embeddings self.tie_word_embeddings = config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.lm_head_multiplier = config.lm_head_multiplier self.lm_head_multiplier = config.lm_head_multiplier
...@@ -607,7 +590,7 @@ class FalconH1ForCausalLM( ...@@ -607,7 +590,7 @@ class FalconH1ForCausalLM(
# Used to track and store by the Mamba cache between steps. # Used to track and store by the Mamba cache between steps.
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size,
config.vocab_size, config.vocab_size,
scale=config.lm_head_multiplier, scale=config.lm_head_multiplier,
) )
......
...@@ -382,12 +382,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -382,12 +382,10 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel( self.model = GemmaModel(
......
...@@ -393,8 +393,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -393,8 +393,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled
......
...@@ -524,8 +524,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -524,8 +524,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled
......
...@@ -1114,8 +1114,7 @@ class Gemma3nForCausalLM(nn.Module): ...@@ -1114,8 +1114,7 @@ class Gemma3nForCausalLM(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
......
...@@ -248,10 +248,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -248,10 +248,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Glm4Model( self.model = Glm4Model(
......
...@@ -207,18 +207,13 @@ class GPTBigCodeModel(nn.Module): ...@@ -207,18 +207,13 @@ class GPTBigCodeModel(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
) )
...@@ -290,10 +285,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -290,10 +285,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel( self.transformer = GPTBigCodeModel(
...@@ -305,15 +298,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -305,15 +298,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.transformer.vocab_size, self.transformer.vocab_size,
self.transformer.embed_dim, self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
self.unpadded_vocab_size = config.vocab_size
if lora_config: self.logits_processor = LogitsProcessor(config.vocab_size)
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment