Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -550,8 +550,8 @@ class Idefics3Model(nn.Module): ...@@ -550,8 +550,8 @@ class Idefics3Model(nn.Module):
return image_hidden_states return image_hidden_states
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.text_model.get_input_embeddings(input_ids) return self.text_model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
...@@ -674,7 +674,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo ...@@ -674,7 +674,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -94,7 +94,7 @@ class SupportsMultiModal(Protocol): ...@@ -94,7 +94,7 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
...@@ -104,7 +104,13 @@ class SupportsMultiModal(Protocol): ...@@ -104,7 +104,13 @@ class SupportsMultiModal(Protocol):
the appearances of their corresponding multimodal data item in the the appearances of their corresponding multimodal data item in the
input prompt. input prompt.
""" """
... if hasattr(self, "get_multimodal_embeddings"):
logger.warning_once(
"`get_multimodal_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_multimodal`."
)
return self.get_multimodal_embeddings(**kwargs)
def get_language_model(self) -> VllmModel: def get_language_model(self) -> VllmModel:
""" """
...@@ -119,10 +125,10 @@ class SupportsMultiModal(Protocol): ...@@ -119,10 +125,10 @@ class SupportsMultiModal(Protocol):
... ...
@overload @overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ... def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
@overload @overload
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings, multimodal_embeddings: MultiModalEmbeddings,
...@@ -131,17 +137,17 @@ class SupportsMultiModal(Protocol): ...@@ -131,17 +137,17 @@ class SupportsMultiModal(Protocol):
handle_oov_mm_token: bool = False, handle_oov_mm_token: bool = False,
) -> Tensor: ... ) -> Tensor: ...
def _get_text_embeddings( def _embed_text_input_ids(
self, self,
input_ids: Tensor, input_ids: Tensor,
get_input_embeddings: Callable[[Tensor], Tensor], embed_input_ids: Callable[[Tensor], Tensor],
*, *,
is_multimodal: Tensor | None, is_multimodal: Tensor | None,
handle_oov_mm_token: bool, handle_oov_mm_token: bool,
) -> Tensor: ) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None: if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text]) text_embeds = embed_input_ids(input_ids[is_text])
return torch.empty( return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]), (input_ids.shape[0], text_embeds.shape[1]),
...@@ -149,9 +155,9 @@ class SupportsMultiModal(Protocol): ...@@ -149,9 +155,9 @@ class SupportsMultiModal(Protocol):
device=text_embeds.device, device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds) ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids) return embed_input_ids(input_ids)
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -167,15 +173,15 @@ class SupportsMultiModal(Protocol): ...@@ -167,15 +173,15 @@ class SupportsMultiModal(Protocol):
In case the multi-modal token IDs exceed the vocabulary size of In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False` the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method to avoid calling the language model's `embed_input_ids` method
on those tokens. Note however that doing so increases memory usage on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings. as an additional buffer is needed to hold the input embeddings.
""" """
from .utils import _merge_multimodal_embeddings from .utils import _merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.get_language_model().get_input_embeddings, self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
...@@ -185,7 +191,7 @@ class SupportsMultiModal(Protocol): ...@@ -185,7 +191,7 @@ class SupportsMultiModal(Protocol):
if is_multimodal is None: if is_multimodal is None:
raise ValueError( raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, " "`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to " "please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229." "https://github.com/vllm-project/vllm/pull/16229."
) )
......
...@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) ...@@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
class VllmModel(Protocol[T_co]): class VllmModel(Protocol[T_co]):
"""The interface required for all models in vLLM.""" """The interface required for all models in vLLM."""
def __init__( def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None: ...
def get_input_embeddings( def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
"""Apply token embeddings to `input_ids`.""" """Apply token embeddings to `input_ids`."""
... if hasattr(self, "get_input_embeddings"):
logger.warning_once(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
return self.get_input_embeddings(input_ids)
def forward( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> T_co: ...
def _check_vllm_model_init(model: type[object] | object) -> bool: def _check_vllm_model_init(model: type[object] | object) -> bool:
...@@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool: ...@@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
return supports_kw(model_init, "vllm_config") return supports_kw(model_init, "vllm_config")
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool: def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
model_get_input_embeddings = getattr(model, "get_input_embeddings", None) model_embed_input_ids = getattr(model, "embed_input_ids", None)
if not callable(model_get_input_embeddings): if not callable(model_embed_input_ids):
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
if callable(model_get_input_embeddings):
logger.warning(
"`get_input_embeddings` for vLLM models is deprecated and will be "
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
"this method to `embed_input_ids`."
)
model.embed_input_ids = model_get_input_embeddings
logger.warning( logger.warning(
"The model (%s) is missing the `get_input_embeddings` method.", "The model (%s) is missing the `embed_input_ids` method.",
model, model,
) )
return False return False
...@@ -110,7 +113,7 @@ def is_vllm_model( ...@@ -110,7 +113,7 @@ def is_vllm_model(
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]: ) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
return ( return (
_check_vllm_model_init(model) _check_vllm_model_init(model)
and _check_vllm_model_get_input_embeddings(model) and _check_vllm_model_embed_input_ids(model)
and _check_vllm_model_forward(model) and _check_vllm_model_forward(model)
) )
......
...@@ -284,7 +284,7 @@ class InternLM2Model(nn.Module): ...@@ -284,7 +284,7 @@ class InternLM2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids) return self.tok_embeddings(input_ids)
def forward( def forward(
...@@ -298,7 +298,7 @@ class InternLM2Model(nn.Module): ...@@ -298,7 +298,7 @@ class InternLM2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -350,8 +350,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -350,8 +350,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -742,7 +742,7 @@ class InternS1ForConditionalGeneration( ...@@ -742,7 +742,7 @@ class InternS1ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
...@@ -765,7 +765,7 @@ class InternS1ForConditionalGeneration( ...@@ -765,7 +765,7 @@ class InternS1ForConditionalGeneration(
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -778,9 +778,9 @@ class InternS1ForConditionalGeneration( ...@@ -778,9 +778,9 @@ class InternS1ForConditionalGeneration(
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -1344,7 +1344,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1344,7 +1344,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
...@@ -1367,7 +1367,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1367,7 +1367,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1380,9 +1380,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1380,9 +1380,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -275,7 +275,7 @@ class JAISModel(nn.Module): ...@@ -275,7 +275,7 @@ class JAISModel(nn.Module):
["hidden_states"], config.n_embd ["hidden_states"], config.n_embd
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -287,7 +287,7 @@ class JAISModel(nn.Module): ...@@ -287,7 +287,7 @@ class JAISModel(nn.Module):
) -> IntermediateTensors | torch.Tensor: ) -> IntermediateTensors | torch.Tensor:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
if self.wpe is not None: if self.wpe is not None:
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
...@@ -339,8 +339,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -339,8 +339,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -340,7 +340,7 @@ class JambaModel(nn.Module): ...@@ -340,7 +340,7 @@ class JambaModel(nn.Module):
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -354,7 +354,7 @@ class JambaModel(nn.Module): ...@@ -354,7 +354,7 @@ class JambaModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -508,8 +508,8 @@ class JambaForCausalLM( ...@@ -508,8 +508,8 @@ class JambaForCausalLM(
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1484,9 +1484,7 @@ class BaseKeyeModule(nn.Module): ...@@ -1484,9 +1484,7 @@ class BaseKeyeModule(nn.Module):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return None return None
......
...@@ -439,7 +439,7 @@ class KimiLinearModel(nn.Module): ...@@ -439,7 +439,7 @@ class KimiLinearModel(nn.Module):
"num_attention_heads must be divisible by world_size" "num_attention_heads must be divisible by world_size"
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -454,7 +454,7 @@ class KimiLinearModel(nn.Module): ...@@ -454,7 +454,7 @@ class KimiLinearModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -504,8 +504,8 @@ class KimiLinearForCausalLM( ...@@ -504,8 +504,8 @@ class KimiLinearForCausalLM(
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -404,7 +404,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -404,7 +404,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> NestedTensors | None: def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
# Validate the multimodal input keyword arguments # Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
......
...@@ -351,7 +351,7 @@ class Lfm2Model(nn.Module): ...@@ -351,7 +351,7 @@ class Lfm2Model(nn.Module):
else: else:
self.embedding_norm = PPMissingLayer() self.embedding_norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -365,7 +365,7 @@ class Lfm2Model(nn.Module): ...@@ -365,7 +365,7 @@ class Lfm2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -504,8 +504,8 @@ class Lfm2ForCausalLM( ...@@ -504,8 +504,8 @@ class Lfm2ForCausalLM(
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -466,7 +466,7 @@ class Lfm2MoeModel(nn.Module): ...@@ -466,7 +466,7 @@ class Lfm2MoeModel(nn.Module):
else: else:
self.embedding_norm = PPMissingLayer() self.embedding_norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -480,7 +480,7 @@ class Lfm2MoeModel(nn.Module): ...@@ -480,7 +480,7 @@ class Lfm2MoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -714,8 +714,8 @@ class Lfm2MoeForCausalLM( ...@@ -714,8 +714,8 @@ class Lfm2MoeForCausalLM(
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
......
...@@ -424,7 +424,7 @@ class LlamaModel(nn.Module): ...@@ -424,7 +424,7 @@ class LlamaModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -438,7 +438,7 @@ class LlamaModel(nn.Module): ...@@ -438,7 +438,7 @@ class LlamaModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -640,8 +640,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -640,8 +640,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
): ):
return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -82,7 +82,7 @@ class LlamaModel(nn.Module): ...@@ -82,7 +82,7 @@ class LlamaModel(nn.Module):
) )
self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -93,7 +93,7 @@ class LlamaModel(nn.Module): ...@@ -93,7 +93,7 @@ class LlamaModel(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None residual = None
for layer in self.layers: for layer in self.layers:
...@@ -195,7 +195,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -195,7 +195,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore
def forward( def forward(
self, self,
......
...@@ -84,7 +84,7 @@ class LlamaModel(nn.Module): ...@@ -84,7 +84,7 @@ class LlamaModel(nn.Module):
self.config.hidden_size * 2, self.config.hidden_size, bias=False self.config.hidden_size * 2, self.config.hidden_size, bias=False
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -158,8 +158,8 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): ...@@ -158,8 +158,8 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -172,7 +172,7 @@ class LlamaModel(nn.Module): ...@@ -172,7 +172,7 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps, eps=self.config.rms_norm_eps,
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -183,7 +183,7 @@ class LlamaModel(nn.Module): ...@@ -183,7 +183,7 @@ class LlamaModel(nn.Module):
input_embeds: torch.Tensor | None = None, input_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if input_embeds is None: if input_embeds is None:
input_embeds = self.get_input_embeddings(input_ids) input_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == input_embeds.shape[-1] assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None residual = None
...@@ -261,13 +261,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -261,13 +261,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False, requires_grad=False,
) )
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None, multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -661,7 +661,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -661,7 +661,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -483,14 +483,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -483,14 +483,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -501,9 +501,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -501,9 +501,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -422,7 +422,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -422,7 +422,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None: if video_input is None:
return [] return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment