Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -220,8 +220,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -220,8 +220,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids) return self.roberta.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module): ...@@ -334,7 +334,7 @@ class SeedOssModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -348,7 +348,7 @@ class SeedOssModel(nn.Module): ...@@ -348,7 +348,7 @@ class SeedOssModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -467,8 +467,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -467,8 +467,8 @@ class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -595,7 +595,7 @@ class SiglipTextTransformer(nn.Module): ...@@ -595,7 +595,7 @@ class SiglipTextTransformer(nn.Module):
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, config.projection_size) self.head = nn.Linear(embed_dim, config.projection_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.token_embedding(input_ids) return self.embeddings.token_embedding(input_ids)
def forward( def forward(
...@@ -1117,7 +1117,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1117,7 +1117,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.text_model return self.text_model
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1130,16 +1130,16 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1130,16 +1130,16 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
) )
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -872,14 +872,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -872,14 +872,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -892,9 +892,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -892,9 +892,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -310,7 +310,7 @@ class SolarModel(nn.Module): ...@@ -310,7 +310,7 @@ class SolarModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -324,7 +324,7 @@ class SolarModel(nn.Module): ...@@ -324,7 +324,7 @@ class SolarModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -478,8 +478,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -478,8 +478,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module): ...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -260,7 +260,7 @@ class StableLMEpochModel(nn.Module): ...@@ -260,7 +260,7 @@ class StableLMEpochModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -332,8 +332,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -332,8 +332,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -249,7 +249,7 @@ class Starcoder2Model(nn.Module): ...@@ -249,7 +249,7 @@ class Starcoder2Model(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -263,7 +263,7 @@ class Starcoder2Model(nn.Module): ...@@ -263,7 +263,7 @@ class Starcoder2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -333,8 +333,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -333,8 +333,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module): ...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -368,7 +368,7 @@ class Step3TextModel(nn.Module): ...@@ -368,7 +368,7 @@ class Step3TextModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -419,8 +419,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -419,8 +419,8 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1075,14 +1075,14 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1075,14 +1075,14 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1093,9 +1093,9 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1093,9 +1093,9 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
...@@ -1113,8 +1113,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1113,8 +1113,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.embed_input_ids(
input_ids, input_ids,
vision_embeddings, vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id, is_multimodal=input_ids == self.config.image_token_id,
......
...@@ -576,7 +576,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -576,7 +576,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
...@@ -593,8 +593,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -593,8 +593,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.embed_input_ids(
input_ids, input_ids,
vision_embeddings, vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index, is_multimodal=input_ids == self.config.image_token_index,
......
...@@ -57,7 +57,7 @@ class TeleFLMModel(LlamaModel): ...@@ -57,7 +57,7 @@ class TeleFLMModel(LlamaModel):
if self.use_mup: if self.use_mup:
self.input_mult = self.config.input_mult self.input_mult = self.config.input_mult
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
if self.use_mup: if self.use_mup:
embedding = embedding * self.input_mult embedding = embedding * self.input_mult
......
...@@ -251,7 +251,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -251,7 +251,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler({"plugin": DummyPooler()}) self.pooler = DispatchPooler({"plugin": DummyPooler()})
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
......
...@@ -385,7 +385,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -385,7 +385,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
_init_parameters(module, dtype) _init_parameters(module, dtype)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids) inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None: if self.embed_scale is not None:
inputs_embeds *= self.embed_scale inputs_embeds *= self.embed_scale
...@@ -416,7 +416,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -416,7 +416,7 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
and input_ids is not None and input_ids is not None
and inputs_embeds is None and inputs_embeds is None
): ):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
input_ids = None input_ids = None
if self.model_config.uses_mrope: if self.model_config.uses_mrope:
......
...@@ -330,7 +330,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -330,7 +330,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
return LanguageModel(self) return LanguageModel(self)
def get_multimodal_embeddings(self, **kwargs): def embed_multimodal(self, **kwargs):
pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None) pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None)
image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None) image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None)
# Model might use `image_patches` instead of `pixel_values` # Model might use `image_patches` instead of `pixel_values`
......
...@@ -579,14 +579,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -579,14 +579,14 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return [] return []
audio_embeddings = self._process_audio_input(audio_input) audio_embeddings = self._process_audio_input(audio_input)
return audio_embeddings return audio_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -597,9 +597,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -597,9 +597,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -474,7 +474,7 @@ def _merge_multimodal_embeddings( ...@@ -474,7 +474,7 @@ def _merge_multimodal_embeddings(
@deprecated( @deprecated(
"`merge_multimodal_embeddings` has been replaced with " "`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be " "`SupportsMultiModal.embed_input_ids` and will be "
"removed in v0.12." "removed in v0.12."
) )
def merge_multimodal_embeddings( def merge_multimodal_embeddings(
......
...@@ -399,7 +399,7 @@ class VoxtralForConditionalGeneration( ...@@ -399,7 +399,7 @@ class VoxtralForConditionalGeneration(
return hidden_states return hidden_states
def get_multimodal_embeddings( def embed_multimodal(
self, **kwargs self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None: ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
......
...@@ -570,7 +570,7 @@ class WhisperDecoder(nn.Module): ...@@ -570,7 +570,7 @@ class WhisperDecoder(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
encoder_hidden_states: torch.Tensor | None, encoder_hidden_states: torch.Tensor | None,
): ):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
positions = self.embed_positions(positions) positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
...@@ -583,7 +583,7 @@ class WhisperDecoder(nn.Module): ...@@ -583,7 +583,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -907,12 +907,12 @@ class WhisperForConditionalGeneration( ...@@ -907,12 +907,12 @@ class WhisperForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model.decoder return self.model.decoder
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface. # Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])] return [self.model.get_encoder_outputs(audio_input["input_features"])]
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -922,7 +922,7 @@ class WhisperForConditionalGeneration( ...@@ -922,7 +922,7 @@ class WhisperForConditionalGeneration(
) -> torch.Tensor: ) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since # This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. # Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids) return self.model.decoder.embed_input_ids(input_ids)
def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs: def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
input_features = kwargs.pop("input_features", None) input_features = kwargs.pop("input_features", None)
......
...@@ -756,7 +756,7 @@ class Zamba2Model(nn.Module): ...@@ -756,7 +756,7 @@ class Zamba2Model(nn.Module):
# Final layer normalization # Final layer normalization
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings. """Convert input token IDs to embeddings.
Args: Args:
...@@ -786,7 +786,7 @@ class Zamba2Model(nn.Module): ...@@ -786,7 +786,7 @@ class Zamba2Model(nn.Module):
""" """
# Handle pipeline parallelism for first rank # Handle pipeline parallelism for first rank
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Process through layers # Process through layers
...@@ -930,14 +930,14 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -930,14 +930,14 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
# Initialize logits processing and sampling # Initialize logits processing and sampling
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Convert input token IDs to embeddings. """Convert input token IDs to embeddings.
Args: Args:
input_ids: Tensor of input token IDs input_ids: Tensor of input token IDs
Returns: Returns:
Embedded representation of the input tokens Embedded representation of the input tokens
""" """
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -208,7 +208,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -208,7 +208,7 @@ class PromptUpdateDetails(Generic[_S]):
`None` (default) means to assign embeddings to all positions of `full`. `None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling The embeddings are obtained by calling
[`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings]. [`SupportsMultiModal.embed_multimodal`][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal].
""" """
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment