Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
......@@ -220,8 +220,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module):
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -348,7 +348,7 @@ class SeedOssModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -467,8 +467,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -595,7 +595,7 @@ class SiglipTextTransformer(nn.Module):
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, config.projection_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.token_embedding(input_ids)
def forward(
......@@ -1117,7 +1117,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -1130,16 +1130,16 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
)
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -872,14 +872,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -892,9 +892,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -310,7 +310,7 @@ class SolarModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -324,7 +324,7 @@ class SolarModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -478,8 +478,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -260,7 +260,7 @@ class StableLMEpochModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -332,8 +332,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -249,7 +249,7 @@ class Starcoder2Model(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -263,7 +263,7 @@ class Starcoder2Model(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -333,8 +333,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -368,7 +368,7 @@ class Step3TextModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -419,8 +419,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -1075,14 +1075,14 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -1093,9 +1093,9 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......@@ -1113,8 +1113,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
......
......@@ -576,7 +576,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......@@ -593,8 +593,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
......
......@@ -57,7 +57,7 @@ class TeleFLMModel(LlamaModel):
if self.use_mup:
self.input_mult = self.config.input_mult
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
if self.use_mup:
embedding = embedding * self.input_mult
......
......@@ -251,7 +251,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler({"plugin": DummyPooler()})
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......
......@@ -385,7 +385,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
_init_parameters(module, dtype)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None:
inputs_embeds *= self.embed_scale
......@@ -416,7 +416,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
and input_ids is not None
and inputs_embeds is None
):
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
input_ids = None
if self.model_config.uses_mrope:
......
......@@ -330,7 +330,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
return LanguageModel(self)
def get_multimodal_embeddings(self, **kwargs):
def embed_multimodal(self, **kwargs):
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
# Model might use `image_patches` instead of `pixel_values`
......
......@@ -579,14 +579,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return []
audio_embeddings = self._process_audio_input(audio_input)
return audio_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -597,9 +597,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -474,7 +474,7 @@ def _merge_multimodal_embeddings(
@deprecated(
"`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be "
"`SupportsMultiModal.embed_input_ids` and will be "
"removed in v0.12."
)
def merge_multimodal_embeddings(
......
......@@ -399,7 +399,7 @@ class VoxtralForConditionalGeneration(
return hidden_states
def get_multimodal_embeddings(
def embed_multimodal(
self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
......
......@@ -570,7 +570,7 @@ class WhisperDecoder(nn.Module):
positions: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
):
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions
......@@ -583,7 +583,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -907,12 +907,12 @@ class WhisperForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -922,7 +922,7 @@ class WhisperForConditionalGeneration(
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids)
return self.model.decoder.embed_input_ids(input_ids)
def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
input_features = kwargs.pop("input_features", None)
......
......@@ -756,7 +756,7 @@ class Zamba2Model(nn.Module):
# Final layer normalization
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings.
Args:
......@@ -786,7 +786,7 @@ class Zamba2Model(nn.Module):
"""
# Handle pipeline parallelism for first rank
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = inputs_embeds
# Process through layers
......@@ -930,14 +930,14 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
# Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings.
Args:
input_ids: Tensor of input token IDs
Returns:
Embedded representation of the input tokens
"""
return self.model.get_input_embeddings(input_ids)
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -208,7 +208,7 @@ class PromptUpdateDetails(Generic[_S]):
`None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling
[`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
[`SupportsMultiModal.embed_multimodal`][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
"""
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment