Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -351,7 +351,7 @@ class NemotronModel(nn.Module): ...@@ -351,7 +351,7 @@ class NemotronModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -365,7 +365,7 @@ class NemotronModel(nn.Module): ...@@ -365,7 +365,7 @@ class NemotronModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -491,8 +491,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -491,8 +491,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -548,7 +548,7 @@ class NemotronHModel(nn.Module): ...@@ -548,7 +548,7 @@ class NemotronHModel(nn.Module):
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -562,7 +562,7 @@ class NemotronHModel(nn.Module): ...@@ -562,7 +562,7 @@ class NemotronHModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -823,8 +823,8 @@ class NemotronHForCausalLM( ...@@ -823,8 +823,8 @@ class NemotronHForCausalLM(
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -291,7 +291,7 @@ class DeciModel(nn.Module): ...@@ -291,7 +291,7 @@ class DeciModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -305,7 +305,7 @@ class DeciModel(nn.Module): ...@@ -305,7 +305,7 @@ class DeciModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -461,8 +461,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): ...@@ -461,8 +461,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return DeciModel(vllm_config=vllm_config, prefix=prefix) return DeciModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -561,7 +561,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -561,7 +561,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
...@@ -580,7 +580,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -580,7 +580,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -593,9 +593,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -593,9 +593,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -268,7 +268,7 @@ class OlmoModel(nn.Module): ...@@ -268,7 +268,7 @@ class OlmoModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -285,7 +285,7 @@ class OlmoModel(nn.Module): ...@@ -285,7 +285,7 @@ class OlmoModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -379,8 +379,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -379,8 +379,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -304,7 +304,7 @@ class Olmo2Model(nn.Module): ...@@ -304,7 +304,7 @@ class Olmo2Model(nn.Module):
["hidden_states"], self.config.hidden_size ["hidden_states"], self.config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -419,8 +419,8 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -419,8 +419,8 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -296,7 +296,7 @@ class OlmoeModel(nn.Module): ...@@ -296,7 +296,7 @@ class OlmoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -310,7 +310,7 @@ class OlmoeModel(nn.Module): ...@@ -310,7 +310,7 @@ class OlmoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -471,8 +471,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -471,8 +471,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -753,7 +753,7 @@ class OpenPanguModel(nn.Module): ...@@ -753,7 +753,7 @@ class OpenPanguModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -767,7 +767,7 @@ class OpenPanguModel(nn.Module): ...@@ -767,7 +767,7 @@ class OpenPanguModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -969,8 +969,8 @@ class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): ...@@ -969,8 +969,8 @@ class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -100,8 +100,8 @@ class OpenPanguMTP(nn.Module, SupportsPP): ...@@ -100,8 +100,8 @@ class OpenPanguMTP(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -262,7 +262,7 @@ class OPTDecoder(nn.Module): ...@@ -262,7 +262,7 @@ class OPTDecoder(nn.Module):
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -274,7 +274,7 @@ class OPTDecoder(nn.Module): ...@@ -274,7 +274,7 @@ class OPTDecoder(nn.Module):
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
pos_embeds = self.embed_positions(positions) pos_embeds = self.embed_positions(positions)
if self.project_in is not None: if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds) inputs_embeds, _ = self.project_in(inputs_embeds)
...@@ -311,8 +311,8 @@ class OPTModel(nn.Module): ...@@ -311,8 +311,8 @@ class OPTModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids) return self.decoder.embed_input_ids(input_ids)
def forward( def forward(
self, self,
...@@ -394,8 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -394,8 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -255,7 +255,7 @@ class OrionModel(nn.Module): ...@@ -255,7 +255,7 @@ class OrionModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -269,7 +269,7 @@ class OrionModel(nn.Module): ...@@ -269,7 +269,7 @@ class OrionModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -345,8 +345,8 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -345,8 +345,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -361,7 +361,7 @@ class OuroModel(nn.Module): ...@@ -361,7 +361,7 @@ class OuroModel(nn.Module):
self.total_ut_steps = getattr(self.config, "total_ut_steps", 4) self.total_ut_steps = getattr(self.config, "total_ut_steps", 4)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -374,7 +374,7 @@ class OuroModel(nn.Module): ...@@ -374,7 +374,7 @@ class OuroModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
for current_ut in range(self.total_ut_steps): for current_ut in range(self.total_ut_steps):
residual = None residual = None
...@@ -486,8 +486,8 @@ class OuroForCausalLM(nn.Module, SupportsLoRA): ...@@ -486,8 +486,8 @@ class OuroForCausalLM(nn.Module, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -514,7 +514,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -514,7 +514,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
return tuple(vision_embeddings) return tuple(vision_embeddings)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -617,7 +617,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -617,7 +617,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
return modalities return modalities
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
......
...@@ -1328,10 +1328,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1328,10 +1328,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.embed_multimodal(**kwargs)
is_multimodal = kwargs.pop("is_multimodal", None) is_multimodal = kwargs.pop("is_multimodal", None)
handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False) handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.embed_input_ids(
input_ids, input_ids,
vision_embeddings, vision_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
...@@ -1391,7 +1391,7 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1391,7 +1391,7 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
image_embeds = self.mlp_AR(vision_outputs, image_grid_thw) image_embeds = self.mlp_AR(vision_outputs, image_grid_thw)
return image_embeds return image_embeds
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return () return ()
......
...@@ -375,7 +375,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -375,7 +375,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -270,7 +270,7 @@ class PersimmonModel(nn.Module): ...@@ -270,7 +270,7 @@ class PersimmonModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -284,7 +284,7 @@ class PersimmonModel(nn.Module): ...@@ -284,7 +284,7 @@ class PersimmonModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -347,8 +347,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): ...@@ -347,8 +347,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -240,7 +240,7 @@ class PhiModel(nn.Module): ...@@ -240,7 +240,7 @@ class PhiModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -254,7 +254,7 @@ class PhiModel(nn.Module): ...@@ -254,7 +254,7 @@ class PhiModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -346,8 +346,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -346,8 +346,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -664,14 +664,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -664,14 +664,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -679,7 +679,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -679,7 +679,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False, handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.embed_tokens, self.embed_tokens,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
...@@ -691,7 +691,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -691,7 +691,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if is_multimodal is None: if is_multimodal is None:
raise ValueError( raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, " "`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to " "please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229." "https://github.com/vllm-project/vllm/pull/16229."
) )
......
...@@ -1371,7 +1371,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1371,7 +1371,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
) )
return image_embeds return image_embeds
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment