Unverified Commit fda3f03e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[4/N] Initialize MM components in context managers (M-P) (#32663)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bb917203
...@@ -621,14 +621,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -621,14 +621,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs: object, **kwargs: object,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is None: if intermediate_tensors is not None:
multimodal_embeddings = self.embed_multimodal(**kwargs) inputs_embeds = None
inputs_embeds = self.embed_input_ids(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids, input_ids,
......
...@@ -791,14 +791,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -791,14 +791,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -71,8 +71,6 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: ...@@ -71,8 +71,6 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
class LMMissingLayer(nn.Module): class LMMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def make_empty_intermediate_tensors(self, *args, **kwargs): def make_empty_intermediate_tensors(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode") raise RuntimeError("This module should not be called in MM encoder-only mode")
...@@ -81,8 +79,6 @@ class LMMissingLayer(nn.Module): ...@@ -81,8 +79,6 @@ class LMMissingLayer(nn.Module):
class TowerMissingLayer(nn.Module): class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str] | str) -> None: def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str): if isinstance(modalities, str):
modalities = {modalities} modalities = {modalities}
...@@ -92,7 +88,10 @@ class TowerMissingLayer(nn.Module): ...@@ -92,7 +88,10 @@ class TowerMissingLayer(nn.Module):
self.modalities = modalities self.modalities = modalities
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise RuntimeError(f"The following modalities are disabled: {self.modalities}") raise RuntimeError(
f"This module should not be called when the following "
f"modalities are disabled: {self.modalities}"
)
@contextmanager @contextmanager
......
...@@ -789,7 +789,6 @@ class InternS1ForConditionalGeneration( ...@@ -789,7 +789,6 @@ class InternS1ForConditionalGeneration(
**kwargs: object, **kwargs: object,
) -> IntermediateTensors: ) -> IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
......
...@@ -1379,7 +1379,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1379,7 +1379,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
**kwargs: object, **kwargs: object,
) -> IntermediateTensors: ) -> IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
......
...@@ -707,30 +707,30 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -707,30 +707,30 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config
# Initialize audio components with self._mark_tower_model(vllm_config, "audio"):
self.audio_encoder = DashengAudioTransformer( self.audio_encoder = DashengAudioTransformer(
config.audio_encoder_config, config.audio_encoder_config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "audio_encoder"), prefix=maybe_prefix(prefix, "audio_encoder"),
) )
self.audio_projector = AudioProjectorSubsample( self.audio_projector = AudioProjectorSubsample(
in_dim=config.audio_encoder_config.embed_dim, in_dim=config.audio_encoder_config.embed_dim,
out_dim=config.text_config.hidden_size, out_dim=config.text_config.hidden_size,
downsample_rate=config.subsample_factor, downsample_rate=config.subsample_factor,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "audio_projector"), prefix=maybe_prefix(prefix, "audio_projector"),
) )
# Initialize language model (decoder) with self._mark_language_model(vllm_config):
self.decoder = init_vllm_registered_model( self.decoder = init_vllm_registered_model(
vllm_config=vllm_config, vllm_config=vllm_config,
hf_config=config.text_config, hf_config=config.text_config,
prefix=maybe_prefix(prefix, "decoder"), prefix=maybe_prefix(prefix, "decoder"),
architectures=["Qwen2ForCausalLM"], architectures=["Qwen2ForCausalLM"],
) )
self.quant_config = quant_config
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.decoder.make_empty_intermediate_tensors self.decoder.make_empty_intermediate_tensors
) )
...@@ -787,9 +787,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -787,9 +787,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return torch.split(masked_audio_features, audio_output_lengths.tolist()) return torch.split(masked_audio_features, audio_output_lengths.tolist())
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
......
...@@ -553,9 +553,11 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -553,9 +553,11 @@ class MiniCPMO(MiniCPMV2_6):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
self.apm = self.init_audio_module(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") with self._mark_tower_model(vllm_config, "audio"):
) self.apm = self.init_audio_module(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
)
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily # Do not use parameters temporarily
......
...@@ -1028,25 +1028,27 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1028,25 +1028,27 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config) self.version = get_version_by_config(self.config)
self.llm = self.init_llm(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
)
self.vpm = self.init_vision_module(
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
)
self.vision_dim = (
self.vpm.embed_dim
if self.version == (2, 0)
else self.vpm.embeddings.embed_dim
)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler( with self._mark_language_model(vllm_config):
self.embed_dim, self.llm = self.init_llm(
self.vision_dim, vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
quant_config=quant_config, )
prefix=maybe_prefix(prefix, "resampler"),
) with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vpm = vpm = self.init_vision_module(
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
)
self.vision_dim = (
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(
self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "resampler"),
)
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
...@@ -1134,9 +1136,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1134,9 +1136,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
......
...@@ -201,28 +201,33 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -201,28 +201,33 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings. with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
quant_config=quant_config, quant_config=quant_config,
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
require_post_norm=False, require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"), prefix=maybe_prefix(prefix, "vision_tower"),
) )
self.multi_modal_projector = MiniMaxVL01MultiModalProjector( self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act, projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=True, multimodal_projector_bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"), prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.image_newline = nn.Parameter(
self.language_model = init_vllm_registered_model( torch.empty(config.text_config.hidden_size)
vllm_config=vllm_config, )
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"), with self._mark_language_model(vllm_config):
) self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_feature_layer = config.vision_feature_layer self.vision_feature_layer = config.vision_feature_layer
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.pad_token_id = -1 self.pad_token_id = -1
...@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _image_pixels_to_features( def _image_pixels_to_features(
self, self,
vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel, vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
...@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self, self,
inputs: MiniMaxVL01ImagePixelInputs, inputs: MiniMaxVL01ImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]: ) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values) return self._image_pixels_to_features(self.vision_tower, pixel_values)
...@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor): if isinstance(image_features, torch.Tensor):
...@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
......
...@@ -1441,15 +1441,20 @@ class MolmoForCausalLM( ...@@ -1441,15 +1441,20 @@ class MolmoForCausalLM(
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
vision_config = VisionBackboneConfig() vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(
config, with self._mark_tower_model(vllm_config, "image"):
vision_config, self.vision_backbone = MolmoVisionBackbone(
quant_config, config,
prefix=maybe_prefix(prefix, "vision_backbone"), vision_config,
) quant_config,
self.model = MolmoModel( prefix=maybe_prefix(prefix, "vision_backbone"),
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") )
)
with self._mark_language_model(vllm_config):
self.model = MolmoModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.img_patch_id = None self.img_patch_id = None
if self.config.weight_tying: if self.config.weight_tying:
...@@ -1525,9 +1530,6 @@ class MolmoForCausalLM( ...@@ -1525,9 +1530,6 @@ class MolmoForCausalLM(
results.append(feats[is_valid][order]) results.append(feats[is_valid][order])
return results return results
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -2514,16 +2514,19 @@ class Molmo2ForConditionalGeneration( ...@@ -2514,16 +2514,19 @@ class Molmo2ForConditionalGeneration(
kwargs[field.name] = getattr(config.adapter_config, field.name) kwargs[field.name] = getattr(config.adapter_config, field.name)
adapter_config = AdapterConfig(**kwargs) adapter_config = AdapterConfig(**kwargs)
self.vision_backbone = Molmo2VisionBackbone( with self._mark_tower_model(vllm_config, {"image", "video"}):
vit_config, self.vision_backbone = Molmo2VisionBackbone(
adapter_config, vit_config,
quant_config, adapter_config,
prefix=maybe_prefix(prefix, "vision_backbone"), quant_config,
) prefix=maybe_prefix(prefix, "vision_backbone"),
self.model = Molmo2TextModel( )
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"), with self._mark_language_model(vllm_config):
) self.model = Molmo2TextModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.img_patch_id = config.image_patch_id self.img_patch_id = config.image_patch_id
...@@ -2687,9 +2690,6 @@ class Molmo2ForConditionalGeneration( ...@@ -2687,9 +2690,6 @@ class Molmo2ForConditionalGeneration(
out.append(out_features) out.append(out_features)
return tuple(out) return tuple(out)
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
......
...@@ -1511,34 +1511,38 @@ class NemotronH_Nano_VL_V2( ...@@ -1511,34 +1511,38 @@ class NemotronH_Nano_VL_V2(
self.ps_version = config.ps_version self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type self.image_tag_type = config.image_tag_type
self.video_pruning_rate = multimodal_config.video_pruning_rate self.video_pruning_rate = multimodal_config.video_pruning_rate
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.dtype
)
# Construct the vision projection. with self._mark_language_model(vllm_config):
vit_hidden_size = config.vit_hidden_size self.language_model = language_model = init_vllm_registered_model(
vision_projection_hidden_size = config.projector_hidden_size vllm_config=vllm_config,
llm_hidden_size = config.text_config.hidden_size hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
self.mlp1 = nn.Sequential( )
RMSNorm(
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2, with self._mark_tower_model(vllm_config, {"image", "video"}):
eps=1e-5, self.vision_model = self.get_vit_model_from_radio_config(config).to(
), self.language_model.config.dtype
nn.Linear( )
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
vision_projection_hidden_size, # Construct the vision projection.
bias=False, vit_hidden_size = config.vit_hidden_size
), vision_projection_hidden_size = config.projector_hidden_size
ReLUSquaredActivation(), llm_hidden_size = config.text_config.hidden_size
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
) mlp1 = nn.Sequential(
self.mlp1 = self.mlp1.to(self.language_model.config.dtype) RMSNorm(
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
eps=1e-5,
),
nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
vision_projection_hidden_size,
bias=False,
),
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = mlp1.to(language_model.config.dtype)
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
...@@ -1909,9 +1913,6 @@ class NemotronH_Nano_VL_V2( ...@@ -1909,9 +1913,6 @@ class NemotronH_Nano_VL_V2(
return multimodal_embeddings return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -1921,7 +1922,6 @@ class NemotronH_Nano_VL_V2( ...@@ -1921,7 +1922,6 @@ class NemotronH_Nano_VL_V2(
**kwargs: object, **kwargs: object,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model( hidden_states = self.language_model(
......
...@@ -820,16 +820,18 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -820,16 +820,18 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.encoder = RadioWithNeck( with self._mark_tower_model(vllm_config, "image"):
config=config, quant_config=quant_config, prefix=f"{prefix}.encoder" self.encoder = RadioWithNeck(
) config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
)
self.decoder = MBartDecoderNoPos( with self._mark_language_model(vllm_config):
config.decoder, self.decoder = MBartDecoderNoPos(
cache_config=cache_config, config.decoder,
quant_config=quant_config, cache_config=cache_config,
prefix=f"{prefix}.decoder", quant_config=quant_config,
) prefix=f"{prefix}.decoder",
)
self.vocab_size = config.decoder.vocab_size self.vocab_size = config.decoder.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
...@@ -883,9 +885,6 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -883,9 +885,6 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = pixel_values.to(dtype) pixel_values = pixel_values.to(dtype)
return self.encoder(pixel_values) return self.encoder(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -385,20 +385,20 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -385,20 +385,20 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
self.downsample_ratio = config.downsample_ratio self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0] with self._mark_tower_model(vllm_config, "image"):
self.vision_model = self._init_vision_model( self.vision_model = self._init_vision_model(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"), prefix=maybe_prefix(prefix, "vision_model"),
) )
self.mlp1 = self._init_mlp1(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config) with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None self.img_context_token_id = None
...@@ -520,8 +520,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -520,8 +520,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"]) image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"] num_patches = image_input["num_patches"]
...@@ -556,9 +554,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -556,9 +554,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
...@@ -609,7 +604,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -609,7 +604,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
**kwargs: object, **kwargs: object,
) -> IntermediateTensors: ) -> IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
......
...@@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"): if modality.startswith("image"):
return "<image>" return IMAGE_TOKEN
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
...@@ -427,20 +427,22 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -427,20 +427,22 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config: PretrainedConfig = config self.config: PretrainedConfig = config
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
prefix=maybe_prefix(prefix, "llm"),
)
self.visual_tokenizer = VisualTokenizer( with self._mark_language_model(vllm_config):
config=config.visual_tokenizer_config, self.llm = init_vllm_registered_model(
quant_config=quant_config, vllm_config=vllm_config.with_hf_config(config.get_text_config()),
prefix=f"{prefix}.visual_tokenizer", prefix=maybe_prefix(prefix, "llm"),
) )
self.vte = VisualEmbedding( with self._mark_tower_model(vllm_config, "image"):
self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size self.visual_tokenizer = VisualTokenizer(
) config=config.visual_tokenizer_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(
self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
)
text_model_type = self.config.get_text_config().model_type text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
...@@ -546,12 +548,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -546,12 +548,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.llm.compute_logits(hidden_states) return self.llm.compute_logits(hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
...@@ -451,6 +451,15 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]) ...@@ -451,6 +451,15 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
dummy_inputs=Ovis2_5DummyInputsBuilder, dummy_inputs=Ovis2_5DummyInputsBuilder,
) )
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return IMAGE_TOKEN
if modality.startswith("video"):
return VIDEO_TOKEN
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -458,20 +467,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -458,20 +467,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config: PretrainedConfig = config self.config: PretrainedConfig = config
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "llm"),
)
self.visual_tokenizer = VisualTokenizer( with self._mark_language_model(vllm_config):
config=config.vit_config, self.llm = init_vllm_registered_model(
visual_vocab_size=config.visual_vocab_size, vllm_config=vllm_config.with_hf_config(config.text_config),
multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "llm"),
quant_config=quant_config, )
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
multimodal_config=multimodal_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
text_model_type = self.config.get_text_config().model_type text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
...@@ -650,12 +661,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -650,12 +661,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
logits = self.llm.compute_logits(hidden_states) return self.llm.compute_logits(hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
...@@ -999,6 +999,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -999,6 +999,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
} }
) )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -1008,22 +1015,24 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1008,22 +1015,24 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.visual = SiglipVisionModel( with self._mark_tower_model(vllm_config, "image"):
config=config.vision_config, self.visual = SiglipVisionModel(
quant_config=quant_config, config=config.vision_config,
multimodal_config=multimodal_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"), multimodal_config=multimodal_config,
) prefix=maybe_prefix(prefix, "visual"),
self.mlp_AR = Projector(config, config.vision_config) )
self.mlp_AR = Projector(config, config.vision_config)
self.language_model = Ernie4_5ForCausalLM( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = language_model = Ernie4_5ForCausalLM(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
) prefix=maybe_prefix(prefix, "language_model"),
)
for layer in self.language_model.model.layers: for layer in language_model.model.layers:
if not isinstance(layer, PPMissingLayer): if not isinstance(layer, PPMissingLayer):
layer.self_attn.rotary_emb.is_neox_style = True layer.self_attn.rotary_emb.is_neox_style = True
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -1151,9 +1160,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1151,9 +1160,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
def get_language_model(self) -> nn.Module:
return self.language_model
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> PaddleOCRImagePixelInputs | None: ) -> PaddleOCRImagePixelInputs | None:
...@@ -1180,29 +1186,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1180,29 +1186,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
is_multimodal = kwargs.pop("is_multimodal", None)
handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
input_ids = None
return self.language_model( return self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds input_ids, positions, intermediate_tensors, inputs_embeds
) )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
raise ValueError("Only image modality is supported")
def encode_image( def encode_image(
self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -295,30 +295,32 @@ class PaliGemmaForConditionalGeneration( ...@@ -295,30 +295,32 @@ class PaliGemmaForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim,
)
self.quant_config = quant_config self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim,
)
if config.text_config.model_type == "gemma": if config.text_config.model_type == "gemma":
config.text_config.architectures = ["GemmaForCausalLM"] config.text_config.architectures = ["GemmaForCausalLM"]
else: else:
config.text_config.architectures = ["Gemma2ForCausalLM"] config.text_config.architectures = ["Gemma2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config, with self._mark_language_model(vllm_config):
hf_config=config.text_config, self.language_model = language_model = init_vllm_registered_model(
prefix=maybe_prefix(prefix, "language_model"), vllm_config=vllm_config,
) hf_config=config.text_config,
logit_scale = getattr(config, "logit_scale", 1.0) prefix=maybe_prefix(prefix, "language_model"),
self.language_model.logits_processor.scale *= logit_scale )
logit_scale = getattr(config, "logit_scale", 1.0)
language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -367,7 +369,6 @@ class PaliGemmaForConditionalGeneration( ...@@ -367,7 +369,6 @@ class PaliGemmaForConditionalGeneration(
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_tower is not None
pixel_values = image_input["data"] pixel_values = image_input["data"]
image_features = self._image_pixels_to_features( image_features = self._image_pixels_to_features(
self.vision_tower, self.vision_tower,
...@@ -376,9 +377,6 @@ class PaliGemmaForConditionalGeneration( ...@@ -376,9 +377,6 @@ class PaliGemmaForConditionalGeneration(
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -586,31 +586,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -586,31 +586,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
self.embed_tokens = VocabParallelEmbedding( with self._mark_tower_model(vllm_config, "image"):
config.vocab_size, self.embed_tokens = VocabParallelEmbedding(
config.hidden_size, config.vocab_size,
quant_config=quant_config, config.hidden_size,
prefix=maybe_prefix(prefix, "model.embed_tokens"), quant_config=quant_config,
) prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding(
self.vision_embed_tokens = Phi3HDImageEmbedding( config,
config, quant_config=quant_config,
quant_config=quant_config, multimodal_config=multimodal_config,
multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), )
)
self.language_model = init_vllm_registered_model( with self._mark_language_model(vllm_config):
vllm_config=vllm_config, self.language_model = init_vllm_registered_model(
# The prefix is empty intentionally because default prefix of vllm_config=vllm_config,
# LlamaForCausalLM is "model" # The prefix is empty intentionally because default prefix of
prefix="", # LlamaForCausalLM is "model"
# We don't directly initialize vLLM's LlamaForCausalLM so we prefix="",
# can automatically apply embedding wrapper if this model is # We don't directly initialize vLLM's LlamaForCausalLM so we
# initialized as an embedding model # can automatically apply embedding wrapper if this model is
architectures=["LlamaForCausalLM"], # initialized as an embedding model
) architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
...@@ -652,17 +652,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -652,17 +652,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return image_input["data"] return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens( image_embeds = self.vision_embed_tokens(
image_input["pixel_values"], image_input["image_sizes"] image_input["pixel_values"], image_input["image_sizes"]
) )
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -1027,12 +1027,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1027,12 +1027,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
# Tensor/Pipeline parallel not supported for now. # Tensor/Pipeline parallel not supported for now.
assert get_pp_group().world_size == 1, "pipeline parallel is not supported" assert get_pp_group().world_size == 1, "pipeline parallel is not supported"
self.vision_encoder = Phi4MMImageEncoder( with self._mark_tower_model(vllm_config, {"image", "video"}):
config, self.vision_encoder = Phi4MMImageEncoder(
quant_config, config,
prefix="model.vision_embed_tokens", quant_config,
model_dir=config._name_or_path, prefix="model.vision_embed_tokens",
) model_dir=config._name_or_path,
)
if isinstance(config.embd_layer["audio_embd_layer"], dict): if isinstance(config.embd_layer["audio_embd_layer"], dict):
embedding_config = { embedding_config = {
...@@ -1044,10 +1045,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1044,10 +1045,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"embedding_cls": self.config.embd_layer["embedding_cls"] "embedding_cls": self.config.embd_layer["embedding_cls"]
} }
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) with self._mark_tower_model(vllm_config, "audio"):
self.model = LlamaModel( self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) with self._mark_language_model(vllm_config):
self.model = LlamaModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
...@@ -1245,6 +1249,3 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1245,6 +1249,3 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
connector=["audio_projection_for_vision", "audio_projection"], connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"], tower_model=["vision_encoder", "embed_tokens_extend"],
) )
def get_language_model(self) -> torch.nn.Module:
return self.model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment