Unverified Commit 4a9375fe authored by whx's avatar whx Committed by GitHub
Browse files

[Model] Pass param prefix to LLMHead (#24862)


Signed-off-by: default avatarwhx-sjtu <2952154980@qq.com>
parent 03191cd8
...@@ -306,6 +306,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -306,6 +306,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
config.n_embd, config.n_embd,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -655,6 +655,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP): ...@@ -655,6 +655,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.config.vocab_size, self.config.vocab_size,
self.config.hidden_size, self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor(self.config.vocab_size) self.logits_processor = LogitsProcessor(self.config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -434,6 +434,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -434,6 +434,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
......
...@@ -487,6 +487,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -487,6 +487,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
......
...@@ -58,7 +58,7 @@ from vllm.sequence import IntermediateTensors ...@@ -58,7 +58,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers) make_layers, maybe_prefix)
def _is_moe(config: PretrainedConfig) -> bool: def _is_moe(config: PretrainedConfig) -> bool:
...@@ -871,6 +871,7 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP): ...@@ -871,6 +871,7 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
......
...@@ -606,6 +606,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -606,6 +606,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
config.text_config.vocab_size, config.text_config.vocab_size,
config.text_config.hidden_size, config.text_config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if self.config.text_config.tie_word_embeddings: if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight self.lm_head.weight = self.model.text_model.wte.weight
......
...@@ -302,7 +302,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -302,7 +302,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
else: else:
self.lm_head = ParallelLMHead(self.config.vocab_size, self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size) self.config.hidden_size,
prefix=maybe_prefix(
prefix, "lm_head"))
if hasattr(config, "width_scale"): if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale self.output_logits_scale = config.width_scale
else: else:
......
...@@ -502,6 +502,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -502,6 +502,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps. # Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None self.mamba_cache: Optional[MambaCacheManager] = None
......
...@@ -328,6 +328,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -328,6 +328,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.text_config.hidden_size, config.text_config.hidden_size,
org_num_embeddings=self.config.text_config.vocab_size, org_num_embeddings=self.config.text_config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -220,7 +220,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -220,7 +220,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self.config.hidden_size, self.config.hidden_size,
org_num_embeddings=self.config.draft_vocab_size, org_num_embeddings=self.config.draft_vocab_size,
padding_size=(DEFAULT_VOCAB_PADDING_SIZE), padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
prefix="") prefix=maybe_prefix(prefix, "lm_head"))
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
scale=logit_scale) scale=logit_scale)
self.draft_id_to_target_id = nn.Parameter( self.draft_id_to_target_id = nn.Parameter(
......
...@@ -223,6 +223,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): ...@@ -223,6 +223,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
) )
# Used to track and store by the Mamba cache between steps. # Used to track and store by the Mamba cache between steps.
......
...@@ -278,6 +278,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -278,6 +278,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
# We need bigger padding if using lora for kernel # We need bigger padding if using lora for kernel
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
......
...@@ -15,6 +15,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -15,6 +15,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from .utils import maybe_prefix
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
...@@ -71,6 +73,7 @@ class Medusa(nn.Module): ...@@ -71,6 +73,7 @@ class Medusa(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=self.truncated_vocab_size, org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
) )
self.lm_heads = [ self.lm_heads = [
self.lm_head for _ in range(self.config.num_heads) self.lm_head for _ in range(self.config.num_heads)
......
...@@ -158,7 +158,8 @@ class MiMoMTP(nn.Module): ...@@ -158,7 +158,8 @@ class MiMoMTP(nn.Module):
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "model")) prefix, "model"))
self.lm_head = ParallelLMHead(self.config.vocab_size, self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size) self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"))
def forward( def forward(
self, self,
......
...@@ -547,6 +547,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -547,6 +547,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
......
...@@ -338,6 +338,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -338,6 +338,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
......
...@@ -702,6 +702,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -702,6 +702,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
self.config.hidden_size, self.config.hidden_size,
org_num_embeddings=self.config.vocab_size, org_num_embeddings=self.config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
......
...@@ -507,6 +507,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -507,6 +507,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
......
...@@ -1403,6 +1403,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1403,6 +1403,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
config.embedding_size or config.vocab_size, config.embedding_size or config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
self.logits_processor = LogitsProcessor(config.embedding_size self.logits_processor = LogitsProcessor(config.embedding_size
......
...@@ -466,6 +466,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -466,6 +466,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment