Unverified Commit 7a6ebcbf authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Remove unnecessary `get_language_model` (#37545)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent c7bc12c2
...@@ -1704,6 +1704,12 @@ class ConformerEncoder(nn.Module): ...@@ -1704,6 +1704,12 @@ class ConformerEncoder(nn.Module):
# ----- Encoder END ----- # ----- Encoder END -----
# This subclass is specific to vLLM in order for
# `_mark_composite_model` to target this module
class CohereASRProjector(nn.Linear):
pass
class CohereASRModel(nn.Module): class CohereASRModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -1714,7 +1720,7 @@ class CohereASRModel(nn.Module): ...@@ -1714,7 +1720,7 @@ class CohereASRModel(nn.Module):
) )
if self.encoder.d_model != self.decoder.hidden_size: if self.encoder.d_model != self.decoder.hidden_size:
self.encoder_decoder_proj = torch.nn.Linear( self.encoder_decoder_proj = CohereASRProjector(
self.encoder.d_model, self.decoder.hidden_size self.encoder.d_model, self.decoder.hidden_size
) )
...@@ -2096,18 +2102,25 @@ class CohereASRForConditionalGeneration( ...@@ -2096,18 +2102,25 @@ class CohereASRForConditionalGeneration(
self.config = config self.config = config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix) with self._mark_composite_model(
lm_head_config = config.head vllm_config,
self.unpadded_vocab_size = lm_head_config["num_classes"] language_targets=CohereASRDecoder,
tower_targets={"audio": (ConformerEncoder, CohereASRProjector)},
):
self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix)
head_config = config.head
self.proj_out = ParallelLMHead( self.proj_out = ParallelLMHead(
lm_head_config["num_classes"], head_config["num_classes"],
lm_head_config["hidden_size"], head_config["hidden_size"],
quant_config=quant_config, quant_config=quant_config,
bias=True, bias=True,
) # NOTE: bias is True ) # NOTE: bias is True
logit_scale = getattr(lm_head_config, "logit_scale", 1.0)
logit_scale = getattr(head_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, lm_head_config["num_classes"], logit_scale head_config["num_classes"], scale=logit_scale
) )
def forward( def forward(
......
...@@ -1373,7 +1373,6 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1373,7 +1373,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
"""compute logits"""
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def _vision_forward( def _vision_forward(
......
...@@ -754,12 +754,17 @@ class FireRedASR2ForConditionalGeneration( ...@@ -754,12 +754,17 @@ class FireRedASR2ForConditionalGeneration(
self.config = config self.config = config
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.model = FireRedASR2Model( with self._mark_composite_model(
vllm_config=vllm_config, vllm_config,
prefix=maybe_prefix(prefix, "model"), language_targets=Qwen2ForCausalLM,
) tower_targets={"audio": (FireRedASR2Encoder, FireRedASR2Adapter)},
logit_scale = getattr(config, "logit_scale", 1.0) ):
self.model = FireRedASR2Model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
def forward( def forward(
......
...@@ -470,15 +470,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -470,15 +470,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.vision_config = vision_config self.vision_config = vision_config
self.text_config = text_config self.text_config = text_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
# Initialize Qwen2.5 Vision Transformer
self.visual = Qwen2_5_VisionTransformer(
vision_config=vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
# Linear projector (vision_hidden_size -> text_hidden_size) # Linear projector (vision_hidden_size -> text_hidden_size)
# For V2 model: mm_projector_type is "linear" # For V2 model: mm_projector_type is "linear"
...@@ -492,18 +483,21 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -492,18 +483,21 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
out_hidden = vision_hidden_size out_hidden = vision_hidden_size
# Always create Linear projector since HF checkpoint has mm_projector weights with self._mark_tower_model(vllm_config, {"image", "video"}):
self.mm_projector = nn.Linear(out_hidden, text_hidden_size) self.visual = Qwen2_5_VisionTransformer(
vision_config=vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.mm_projector = nn.Linear(out_hidden, text_hidden_size)
# Language model with self._mark_language_model(vllm_config):
self.lm_head_vocab_size = getattr( self.language_model = init_vllm_registered_model(
text_config, "padded_vocab_size", text_config.vocab_size vllm_config=vllm_config,
) hf_config=text_config,
self.language_model = init_vllm_registered_model( prefix=maybe_prefix(prefix, "language_model"),
vllm_config=vllm_config, )
hf_config=text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -633,9 +627,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -633,9 +627,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return modalities return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal( def embed_multimodal(
self, self,
**kwargs: object, **kwargs: object,
......
...@@ -576,20 +576,19 @@ class InternS1ProForConditionalGeneration( ...@@ -576,20 +576,19 @@ class InternS1ProForConditionalGeneration(
multimodal_config.is_multimodal_pruning_enabled() multimodal_config.is_multimodal_pruning_enabled()
) )
if not multimodal_config.get_limit_per_prompt( with self._mark_tower_model(vllm_config, {"image", "video"}):
"image"
) and not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer( self.visual = Qwen3_VisionTransformer(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.language_model = InternS1ProMoeLLMForCausalLM( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") self.language_model = InternS1ProMoeLLMForCausalLM(
) vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
# Whether to include the gate_up_proj mapping is determined by # Whether to include the gate_up_proj mapping is determined by
# the language model. # the language model.
self.packed_modules_mapping = ( self.packed_modules_mapping = (
......
...@@ -15,7 +15,6 @@ from transformers import WhisperConfig as HFWhisperConfig ...@@ -15,7 +15,6 @@ from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
...@@ -54,7 +53,6 @@ from vllm.tokenizers import cached_get_tokenizer ...@@ -54,7 +53,6 @@ from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.transformers_utils.processor import cached_feature_extractor_from_config from vllm.transformers_utils.processor import cached_feature_extractor_from_config
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
from vllm.v1.sample.metadata import SamplingMetadata
# Kimi-Audio constants # Kimi-Audio constants
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3" KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
...@@ -431,28 +429,24 @@ class KimiAudioForConditionalGeneration( ...@@ -431,28 +429,24 @@ class KimiAudioForConditionalGeneration(
) )
] ]
self.audio_tower = KimiAudioWhisperEncoder( with self._mark_tower_model(vllm_config, "audio"):
vllm_config=vllm_config, self.audio_tower = KimiAudioWhisperEncoder(
prefix=maybe_prefix(prefix, "audio_tower"), vllm_config=vllm_config,
) prefix=maybe_prefix(prefix, "audio_tower"),
)
self.multi_modal_projector = KimiAudioMultiModalProjector( self.multi_modal_projector = KimiAudioMultiModalProjector(
whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120), whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120),
llm_dim=self.config.hidden_size, llm_dim=self.config.hidden_size,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(
self.config, architectures=["Qwen2ForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.logits_processor = LogitsProcessor( with self._mark_language_model(vllm_config):
self.config.vocab_size, self.language_model = init_vllm_registered_model(
self.config.vocab_size, vllm_config=vllm_config.with_hf_config(
) self.config, architectures=["Qwen2ForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -595,12 +589,8 @@ class KimiAudioForConditionalGeneration( ...@@ -595,12 +589,8 @@ class KimiAudioForConditionalGeneration(
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata | None = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.logits_processor( return self.language_model.compute_logits(hidden_states)
self.language_model.lm_head, hidden_states, sampling_metadata
)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights, skipping MIMO layers (TTS-only) for ASR.""" """Load weights, skipping MIMO layers (TTS-only) for ASR."""
......
...@@ -163,29 +163,30 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): ...@@ -163,29 +163,30 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_tower = init_vision_tower_for_llava( with self._mark_tower_model(vllm_config, "image"):
config, self.vision_tower = init_vision_tower_for_llava(
quant_config=quant_config, config,
require_post_norm=False, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"), require_post_norm=False,
) prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = Mistral3MultiModalProjector( self.multi_modal_projector = Mistral3MultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act, projector_hidden_act=config.projector_hidden_act,
spatial_merge_size=config.spatial_merge_size, spatial_merge_size=config.spatial_merge_size,
patch_size=config.vision_config.patch_size, patch_size=config.vision_config.patch_size,
multimodal_projector_bias=config.multimodal_projector_bias, multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
hf_config=config.text_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"), hf_config=config.text_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment