Unverified Commit fda3f03e authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[4/N] Initialize MM components in context managers (M-P) (#32663)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bb917203
......@@ -621,14 +621,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is None:
multimodal_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model(
input_ids,
......
......@@ -791,14 +791,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None
hidden_states = self.language_model(
input_ids=input_ids,
......
......@@ -71,8 +71,6 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
class LMMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def make_empty_intermediate_tensors(self, *args, **kwargs):
raise RuntimeError("This module should not be called in MM encoder-only mode")
......@@ -81,8 +79,6 @@ class LMMissingLayer(nn.Module):
class TowerMissingLayer(nn.Module):
packed_modules_mapping: dict[str, list[str]] = {}
def __init__(self, modalities: set[str] | str) -> None:
if isinstance(modalities, str):
modalities = {modalities}
......@@ -92,7 +88,10 @@ class TowerMissingLayer(nn.Module):
self.modalities = modalities
def __call__(self, *args, **kwargs):
raise RuntimeError(f"The following modalities are disabled: {self.modalities}")
raise RuntimeError(
f"This module should not be called when the following "
f"modalities are disabled: {self.modalities}"
)
@contextmanager
......
......@@ -789,7 +789,6 @@ class InternS1ForConditionalGeneration(
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
forward_kwargs = {
......
......@@ -1379,7 +1379,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
forward_kwargs = {
......
......@@ -707,8 +707,9 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
# Initialize audio components
with self._mark_tower_model(vllm_config, "audio"):
self.audio_encoder = DashengAudioTransformer(
config.audio_encoder_config,
quant_config=quant_config,
......@@ -722,7 +723,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "audio_projector"),
)
# Initialize language model (decoder)
with self._mark_language_model(vllm_config):
self.decoder = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
......@@ -730,7 +731,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
architectures=["Qwen2ForCausalLM"],
)
self.quant_config = quant_config
self.make_empty_intermediate_tensors = (
self.decoder.make_empty_intermediate_tensors
)
......@@ -787,9 +787,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return torch.split(masked_audio_features, audio_output_lengths.tolist())
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
......
......@@ -553,6 +553,8 @@ class MiniCPMO(MiniCPMV2_6):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
with self._mark_tower_model(vllm_config, "audio"):
self.apm = self.init_audio_module(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
)
......
......@@ -1028,16 +1028,18 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
with self._mark_language_model(vllm_config):
self.llm = self.init_llm(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
)
self.vpm = self.init_vision_module(
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vpm = vpm = self.init_vision_module(
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
)
self.vision_dim = (
self.vpm.embed_dim
if self.version == (2, 0)
else self.vpm.embeddings.embed_dim
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
)
self.embed_dim = self.config.hidden_size
......@@ -1134,9 +1136,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
......
......@@ -201,7 +201,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
......@@ -217,12 +217,17 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_feature_layer = config.vision_feature_layer
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = -1
......@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.language_model.make_empty_intermediate_tensors
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _image_pixels_to_features(
self,
vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
......@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self,
inputs: MiniMaxVL01ImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
......@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
if isinstance(image_features, torch.Tensor):
......@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
......
......@@ -1441,15 +1441,20 @@ class MolmoForCausalLM(
self.multimodal_config = multimodal_config
vision_config = VisionBackboneConfig()
with self._mark_tower_model(vllm_config, "image"):
self.vision_backbone = MolmoVisionBackbone(
config,
vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_backbone"),
)
with self._mark_language_model(vllm_config):
self.model = MolmoModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.img_patch_id = None
if self.config.weight_tying:
......@@ -1525,9 +1530,6 @@ class MolmoForCausalLM(
results.append(feats[is_valid][order])
return results
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -2514,12 +2514,15 @@ class Molmo2ForConditionalGeneration(
kwargs[field.name] = getattr(config.adapter_config, field.name)
adapter_config = AdapterConfig(**kwargs)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_backbone = Molmo2VisionBackbone(
vit_config,
adapter_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_backbone"),
)
with self._mark_language_model(vllm_config):
self.model = Molmo2TextModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
......@@ -2687,9 +2690,6 @@ class Molmo2ForConditionalGeneration(
out.append(out_features)
return tuple(out)
def get_language_model(self) -> torch.nn.Module:
return self.model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
......
......@@ -1511,11 +1511,15 @@ class NemotronH_Nano_VL_V2(
self.ps_version = config.ps_version
self.image_tag_type = config.image_tag_type
self.video_pruning_rate = multimodal_config.video_pruning_rate
self.language_model = init_vllm_registered_model(
with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.dtype
)
......@@ -1525,7 +1529,7 @@ class NemotronH_Nano_VL_V2(
vision_projection_hidden_size = config.projector_hidden_size
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
mlp1 = nn.Sequential(
RMSNorm(
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
eps=1e-5,
......@@ -1538,7 +1542,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
self.mlp1 = mlp1.to(language_model.config.dtype)
self.config = config
self.model_config = vllm_config.model_config
......@@ -1909,9 +1913,6 @@ class NemotronH_Nano_VL_V2(
return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def forward(
self,
input_ids: torch.Tensor,
......@@ -1921,7 +1922,6 @@ class NemotronH_Nano_VL_V2(
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
hidden_states = self.language_model(
......
......@@ -820,10 +820,12 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
with self._mark_tower_model(vllm_config, "image"):
self.encoder = RadioWithNeck(
config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
)
with self._mark_language_model(vllm_config):
self.decoder = MBartDecoderNoPos(
config.decoder,
cache_config=cache_config,
......@@ -883,9 +885,6 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = pixel_values.to(dtype)
return self.encoder(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -385,21 +385,21 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.visual_token_mask = None
......@@ -520,8 +520,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
......@@ -556,9 +554,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
......@@ -609,7 +604,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
forward_kwargs = {
......
......@@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
return IMAGE_TOKEN
raise ValueError("Only image modality is supported")
......@@ -427,17 +427,19 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
quant_config = vllm_config.quant_config
self.config: PretrainedConfig = config
with self._mark_language_model(vllm_config):
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
prefix=maybe_prefix(prefix, "llm"),
)
with self._mark_tower_model(vllm_config, "image"):
self.visual_tokenizer = VisualTokenizer(
config=config.visual_tokenizer_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(
self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
)
......@@ -546,12 +548,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.llm.compute_logits(hidden_states)
return logits
return self.llm.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
......@@ -451,6 +451,15 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return IMAGE_TOKEN
if modality.startswith("video"):
return VIDEO_TOKEN
raise ValueError("Only image or video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -458,11 +467,14 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_config = vllm_config.model_config.multimodal_config
self.config: PretrainedConfig = config
with self._mark_language_model(vllm_config):
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "llm"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
......@@ -470,7 +482,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
text_model_type = self.config.get_text_config().model_type
......@@ -650,12 +661,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.llm.compute_logits(hidden_states)
return logits
return self.llm.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
......@@ -999,6 +999,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -1008,6 +1015,7 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, "image"):
self.visual = SiglipVisionModel(
config=config.vision_config,
quant_config=quant_config,
......@@ -1016,12 +1024,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
)
self.mlp_AR = Projector(config, config.vision_config)
self.language_model = Ernie4_5ForCausalLM(
with self._mark_language_model(vllm_config):
self.language_model = language_model = Ernie4_5ForCausalLM(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
for layer in self.language_model.model.layers:
for layer in language_model.model.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.rotary_emb.is_neox_style = True
......@@ -1151,9 +1160,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
return llm_positions, mrope_position_delta
def get_language_model(self) -> nn.Module:
return self.language_model
def _parse_and_validate_image_input(
self, **kwargs: object
) -> PaddleOCRImagePixelInputs | None:
......@@ -1180,29 +1186,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.embed_multimodal(**kwargs)
is_multimodal = kwargs.pop("is_multimodal", None)
handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
input_ids = None
return self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
raise ValueError("Only image modality is supported")
def encode_image(
self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
) -> torch.Tensor:
......
......@@ -295,7 +295,9 @@ class PaliGemmaForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
......@@ -306,19 +308,19 @@ class PaliGemmaForConditionalGeneration(
projection_dim=config.vision_config.projection_dim,
)
self.quant_config = quant_config
if config.text_config.model_type == "gemma":
config.text_config.architectures = ["GemmaForCausalLM"]
else:
config.text_config.architectures = ["Gemma2ForCausalLM"]
self.language_model = init_vllm_registered_model(
with self._mark_language_model(vllm_config):
self.language_model = language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
......@@ -367,7 +369,6 @@ class PaliGemmaForConditionalGeneration(
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower,
......@@ -376,9 +377,6 @@ class PaliGemmaForConditionalGeneration(
return self.multi_modal_projector(image_features)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -586,14 +586,13 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID
with self._mark_tower_model(vllm_config, "image"):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
quant_config=quant_config,
......@@ -601,6 +600,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# The prefix is empty intentionally because default prefix of
......@@ -652,17 +652,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(
image_input["pixel_values"], image_input["image_sizes"]
)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -1027,6 +1027,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
# Tensor/Pipeline parallel not supported for now.
assert get_pp_group().world_size == 1, "pipeline parallel is not supported"
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_encoder = Phi4MMImageEncoder(
config,
quant_config,
......@@ -1044,7 +1045,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"embedding_cls": self.config.embd_layer["embedding_cls"]
}
with self._mark_tower_model(vllm_config, "audio"):
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
with self._mark_language_model(vllm_config):
self.model = LlamaModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
......@@ -1245,6 +1249,3 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
def get_language_model(self) -> torch.nn.Module:
return self.model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment