"wrappers/python/vscode:/vscode.git/clone" did not exist on "1f12f28671f167d9b782a904f78aef942d942290"
Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
......@@ -550,8 +550,8 @@ class Idefics3Model(nn.Module):
return image_hidden_states
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.text_model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.text_model.embed_input_ids(input_ids)
def forward(
self,
......@@ -674,7 +674,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -94,7 +94,7 @@ class SupportsMultiModal(Protocol):
"""
...
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
......@@ -104,7 +104,13 @@ class SupportsMultiModal(Protocol):
the appearances of their corresponding multimodal data item in the
input prompt.
"""
...
if hasattr(self, "get_multimodal_embeddings"):
logger.warning_once(
"`get_multimodal_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_multimodal`."
)
return self.get_multimodal_embeddings(**kwargs)
def get_language_model(self) -> VllmModel:
"""
......@@ -119,10 +125,10 @@ class SupportsMultiModal(Protocol):
...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ...
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
@overload
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings,
......@@ -131,17 +137,17 @@ class SupportsMultiModal(Protocol):
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _get_text_embeddings(
def _embed_text_input_ids(
self,
input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor],
embed_input_ids: Callable[[Tensor], Tensor],
*,
is_multimodal: Tensor | None,
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
text_embeds = embed_input_ids(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
......@@ -149,9 +155,9 @@ class SupportsMultiModal(Protocol):
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
return embed_input_ids(input_ids)
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -167,15 +173,15 @@ class SupportsMultiModal(Protocol):
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
to avoid calling the language model's `embed_input_ids` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().get_input_embeddings,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
......@@ -185,7 +191,7 @@ class SupportsMultiModal(Protocol):
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
......
......@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
class VllmModel(Protocol[T_co]):
"""The interface required for all models in vLLM."""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None: ...
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Apply token embeddings to `input_ids`."""
...
if hasattr(self, "get_input_embeddings"):
logger.warning_once(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
return self.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> T_co: ...
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
def _check_vllm_model_init(model: type[object] | object) -> bool:
......@@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if not callable(model_get_input_embeddings):
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
model_embed_input_ids = getattr(model, "embed_input_ids", None)
if not callable(model_embed_input_ids):
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if callable(model_get_input_embeddings):
logger.warning(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
model.embed_input_ids = model_get_input_embeddings
logger.warning(
"The model (%s) is missing the `get_input_embeddings` method.",
"The model (%s) is missing the `embed_input_ids` method.",
model,
)
return False
......@@ -110,7 +113,7 @@ def is_vllm_model(
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
return (
_check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model)
and _check_vllm_model_embed_input_ids(model)
and _check_vllm_model_forward(model)
)
......
......@@ -284,7 +284,7 @@ class InternLM2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)
def forward(
......@@ -298,7 +298,7 @@ class InternLM2Model(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -350,8 +350,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -742,7 +742,7 @@ class InternS1ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
......@@ -765,7 +765,7 @@ class InternS1ForConditionalGeneration(
return multimodal_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -778,9 +778,9 @@ class InternS1ForConditionalGeneration(
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -1344,7 +1344,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
......@@ -1367,7 +1367,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
return multimodal_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -1380,9 +1380,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -275,7 +275,7 @@ class JAISModel(nn.Module):
["hidden_states"], config.n_embd
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
......@@ -287,7 +287,7 @@ class JAISModel(nn.Module):
) -> IntermediateTensors | torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
......@@ -339,8 +339,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -340,7 +340,7 @@ class JambaModel(nn.Module):
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -354,7 +354,7 @@ class JambaModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -508,8 +508,8 @@ class JambaForCausalLM(
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -1484,9 +1484,7 @@ class BaseKeyeModule(nn.Module):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
......
......@@ -439,7 +439,7 @@ class KimiLinearModel(nn.Module):
"num_attention_heads must be divisible by world_size"
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -454,7 +454,7 @@ class KimiLinearModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -504,8 +504,8 @@ class KimiLinearForCausalLM(
self.config.vocab_size, scale=logit_scale
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -404,7 +404,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> NestedTensors | None:
def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
......
......@@ -351,7 +351,7 @@ class Lfm2Model(nn.Module):
else:
self.embedding_norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -365,7 +365,7 @@ class Lfm2Model(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -504,8 +504,8 @@ class Lfm2ForCausalLM(
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -466,7 +466,7 @@ class Lfm2MoeModel(nn.Module):
else:
self.embedding_norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -480,7 +480,7 @@ class Lfm2MoeModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -714,8 +714,8 @@ class Lfm2MoeForCausalLM(
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def update_physical_experts_metadata(
self,
......
......@@ -424,7 +424,7 @@ class LlamaModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -438,7 +438,7 @@ class LlamaModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -640,8 +640,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
):
return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -82,7 +82,7 @@ class LlamaModel(nn.Module):
)
self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -93,7 +93,7 @@ class LlamaModel(nn.Module):
inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
......@@ -195,7 +195,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def get_language_model(self) -> torch.nn.Module:
return self.model
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore
embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore
def forward(
self,
......
......@@ -84,7 +84,7 @@ class LlamaModel(nn.Module):
self.config.hidden_size * 2, self.config.hidden_size, bias=False
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -158,8 +158,8 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
self.config.vocab_size, scale=logit_scale
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -172,7 +172,7 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -183,7 +183,7 @@ class LlamaModel(nn.Module):
input_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if input_embeds is None:
input_embeds = self.get_input_embeddings(input_ids)
input_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None
......@@ -261,13 +261,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False,
)
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -661,7 +661,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -483,14 +483,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -501,9 +501,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -422,7 +422,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:
return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment