Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
......@@ -351,7 +351,7 @@ class NemotronModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -365,7 +365,7 @@ class NemotronModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -491,8 +491,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -548,7 +548,7 @@ class NemotronHModel(nn.Module):
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -562,7 +562,7 @@ class NemotronHModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -823,8 +823,8 @@ class NemotronHForCausalLM(
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -291,7 +291,7 @@ class DeciModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -305,7 +305,7 @@ class DeciModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -461,8 +461,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return DeciModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -561,7 +561,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
......@@ -580,7 +580,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
return multimodal_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -593,9 +593,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -268,7 +268,7 @@ class OlmoModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -285,7 +285,7 @@ class OlmoModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -379,8 +379,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -304,7 +304,7 @@ class Olmo2Model(nn.Module):
["hidden_states"], self.config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -419,8 +419,8 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -296,7 +296,7 @@ class OlmoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -310,7 +310,7 @@ class OlmoeModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -471,8 +471,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -753,7 +753,7 @@ class OpenPanguModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -767,7 +767,7 @@ class OpenPanguModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -969,8 +969,8 @@ class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -100,8 +100,8 @@ class OpenPanguMTP(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -262,7 +262,7 @@ class OPTDecoder(nn.Module):
prefix=f"{prefix}.layers",
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -274,7 +274,7 @@ class OPTDecoder(nn.Module):
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
......@@ -311,8 +311,8 @@ class OPTModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.embed_input_ids(input_ids)
def forward(
self,
......@@ -394,8 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -255,7 +255,7 @@ class OrionModel(nn.Module):
config.hidden_size,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -269,7 +269,7 @@ class OrionModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -345,8 +345,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -361,7 +361,7 @@ class OuroModel(nn.Module):
self.total_ut_steps = getattr(self.config, "total_ut_steps", 4)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -374,7 +374,7 @@ class OuroModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
for current_ut in range(self.total_ut_steps):
residual = None
......@@ -486,8 +486,8 @@ class OuroForCausalLM(nn.Module, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -514,7 +514,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
return tuple(vision_embeddings)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -617,7 +617,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
return modalities
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
......
......@@ -1328,10 +1328,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
vision_embeddings = self.embed_multimodal(**kwargs)
is_multimodal = kwargs.pop("is_multimodal", None)
handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
inputs_embeds = self.get_input_embeddings(
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=is_multimodal,
......@@ -1391,7 +1391,7 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
image_embeds = self.mlp_AR(vision_outputs, image_grid_thw)
return image_embeds
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return ()
......
......@@ -375,7 +375,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -270,7 +270,7 @@ class PersimmonModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -284,7 +284,7 @@ class PersimmonModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -347,8 +347,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -240,7 +240,7 @@ class PhiModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -254,7 +254,7 @@ class PhiModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -346,8 +346,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -664,14 +664,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -679,7 +679,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings(
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
......@@ -691,7 +691,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
......
......@@ -1371,7 +1371,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
)
return image_embeds
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment