Commit 97d1c993 authored by Harry Mellor, committed by GitHub

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
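
The hunks in this commit apply two renames: `get_input_embeddings` becomes `embed_input_ids`, and `get_multimodal_embeddings` becomes `embed_multimodal`. For out-of-tree model code that still uses the old names, the same substitution could be sketched as below. This is a hypothetical helper, not part of the commit; `RENAMES` and `apply_renames` are names made up for illustration. A blind textual rewrite like this would also touch any code that intentionally calls a different library's method of the same name, so its output should be reviewed by hand.

```python
import re

# Old name -> new name, as seen in the hunks of this commit.
RENAMES = {
    "get_input_embeddings": "embed_input_ids",
    "get_multimodal_embeddings": "embed_multimodal",
}

# \b keeps the match from firing inside longer identifiers.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def apply_renames(source: str) -> str:
    """Return `source` with the old method names replaced by the new ones."""
    return _PATTERN.sub(lambda m: RENAMES[m.group(1)], source)


old_line = "hidden_states = self.get_input_embeddings(input_ids)"
new_line = apply_renames(old_line)
print(new_line)
```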
@@ -56,13 +56,13 @@ The initialization code should look like this:
 ### Computation Code

-- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
+- Add a `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.

 ```python
 class MyModel(nn.Module):
     ...

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         ...
 ```
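
The documentation change above says that `embed_input_ids` is equivalent to calling the text embedding layer directly, and the model hunks in this commit show wrapper classes delegating to the inner model. A minimal, dependency-free sketch of that pattern follows. It uses plain Python lists in place of `torch` tensors, and `TinyEmbedding` is a hypothetical stand-in for an embedding layer, so treat it as an illustration of the call structure rather than real model code.

```python
class TinyEmbedding:
    """Hypothetical stand-in for an embedding layer: maps token ids to rows."""

    def __init__(self, table):
        self.table = table

    def __call__(self, input_ids):
        return [self.table[token_id] for token_id in input_ids]


class MyModel:
    def __init__(self):
        self.embed_tokens = TinyEmbedding([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]])

    def embed_input_ids(self, input_ids):
        # Equivalent to calling the text embedding layer directly.
        return self.embed_tokens(input_ids)


class MyModelForCausalLM:
    def __init__(self):
        self.model = MyModel()

    def embed_input_ids(self, input_ids):
        # Delegate to the inner model, mirroring the wrappers in this commit.
        return self.model.embed_input_ids(input_ids)


wrapper = MyModelForCausalLM()
embeddings = wrapper.embed_input_ids([2, 0])
print(embeddings)
```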
......
@@ -36,7 +36,7 @@ Further update the model as follows:
 More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.

-- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
+- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.

     ??? code
@@ -49,7 +49,7 @@ Further update the model as follows:
                 image_features = self.vision_encoder(image_input)
                 return self.multi_modal_projector(image_features)

-            def get_multimodal_embeddings(
+            def embed_multimodal(
                 self,
                 **kwargs: object,
             ) -> MultiModalEmbeddings | None:
@@ -69,7 +69,7 @@ Further update the model as follows:
     !!! note
         By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
         [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
-        This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
+        This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
         You may override this method if additional logic is required for your model when merging embeddings.
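
The note above describes merging multimodal embeddings into the text embeddings at placeholder positions, and the model hunks in this commit pass an `is_multimodal` mask built by comparing `input_ids` with an image token id. The idea can be sketched as follows with plain Python lists. This is not vLLM's implementation: `merge_multimodal_embeddings` and `IMAGE_TOKEN_ID` are hypothetical names, and a real model would operate on tensors.

```python
def merge_multimodal_embeddings(text_embeds, multimodal_embeds, is_multimodal):
    """Replace each row flagged in `is_multimodal` with the next multimodal row."""
    if sum(is_multimodal) != len(multimodal_embeds):
        raise ValueError("placeholder count does not match multimodal embedding count")
    mm_rows = iter(multimodal_embeds)
    return [
        next(mm_rows) if flag else row
        for row, flag in zip(text_embeds, is_multimodal)
    ]


IMAGE_TOKEN_ID = 99  # hypothetical placeholder token id
input_ids = [5, 99, 99, 7]

# Toy text embeddings: each token id repeated as a 2-element row.
text_embeds = [[float(token_id)] * 2 for token_id in input_ids]
image_embeds = [[-1.0, -1.0], [-2.0, -2.0]]

# Same shape of mask as `input_ids == self.config.image_token_index`.
is_multimodal = [token_id == IMAGE_TOKEN_ID for token_id in input_ids]

merged = merge_multimodal_embeddings(text_embeds, image_embeds, is_multimodal)
print(merged)
```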
......
@@ -382,7 +382,7 @@ class ApertusModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -396,7 +396,7 @@ class ApertusModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -557,8 +557,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -239,7 +239,7 @@ class ArceeModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -254,7 +254,7 @@ class ArceeModel(nn.Module):
             hidden_states = (
                 inputs_embeds
                 if inputs_embeds is not None
-                else self.get_input_embeddings(input_ids)
+                else self.embed_input_ids(input_ids)
             )
             residual = None
         else:
@@ -423,8 +423,8 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights into the model (delegates to inner model and handles
......
@@ -442,7 +442,7 @@ class ArcticModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -456,7 +456,7 @@ class ArcticModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -496,8 +496,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
             self.model.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -613,7 +613,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -629,8 +629,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         **kwargs: object,
     ) -> torch.Tensor | IntermediateTensors:
         if inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
+            multimodal_embeddings = self.embed_multimodal(**kwargs)
+            inputs_embeds = self.embed_input_ids(
                 input_ids,
                 multimodal_embeddings,
                 is_multimodal=input_ids == self.config.image_token_index,
......
@@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
......
@@ -309,7 +309,7 @@ class BaiChuanModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -323,7 +323,7 @@ class BaiChuanModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -426,8 +426,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
             self.model.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -438,7 +438,7 @@ class BailingMoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.word_embeddings(input_ids)

     def forward(
@@ -452,7 +452,7 @@ class BailingMoeModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -608,8 +608,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -314,7 +314,7 @@ class BambaModel(nn.Module):
         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -328,7 +328,7 @@ class BambaModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -493,8 +493,8 @@ class BambaForCausalLM(
             self.model.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -375,7 +375,7 @@ class BertModel(nn.Module, SupportsQuant):
         self.embeddings = embedding_class(self.config)
         self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings.word_embeddings(input_ids)

     def forward(
@@ -486,8 +486,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         )
         self.pooler = self._build_pooler(pooler_config)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     def forward(
         self,
@@ -835,8 +835,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
             }
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.bert.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.embed_input_ids(input_ids)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
@@ -893,8 +893,8 @@ class BertForTokenClassification(nn.Module):
             }
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.bert.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.embed_input_ids(input_ids)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
......
@@ -463,7 +463,7 @@ class BertWithRope(nn.Module, SupportsQuant):
         )
         self.pooler = BertPooler(self.config) if add_pooling_layer else None

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings(input_ids)

     def forward(
@@ -714,8 +714,8 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
         loaded_params = loader.load_weights(weights)
         return loaded_params

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.new.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.new.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -630,7 +630,7 @@ class Blip2ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
......
@@ -271,7 +271,7 @@ class BloomModel(nn.Module):
             ["hidden_states"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.word_embeddings(input_ids)

     def forward(
@@ -285,7 +285,7 @@ class BloomModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             hidden_states = self.word_embeddings_layernorm(hidden_states)
         else:
             assert intermediate_tensors is not None
@@ -353,8 +353,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
             self.transformer.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)

     def forward(
         self,
......
@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
@@ -912,7 +912,7 @@ class ChameleonModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -998,7 +998,7 @@ class ChameleonForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -1006,7 +1006,7 @@ class ChameleonForConditionalGeneration(
         image_tokens = self.model.get_image_tokens(
             image_input["data"].to(self.config.dtype)
         )
-        vision_embeddings = self.model.get_input_embeddings(image_tokens)
+        vision_embeddings = self.model.embed_input_ids(image_tokens)
         return vision_embeddings

     def forward(
......
@@ -353,7 +353,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
             self.encoder.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)

     def forward(
@@ -368,7 +368,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -451,8 +451,8 @@ class ChatGLMBaseModel(nn.Module):
             self.transformer.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)

     def compute_logits(
         self,
......
@@ -561,7 +561,7 @@ class CLIPTextTransformer(nn.Module):
             eps=config.layer_norm_eps,
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings.token_embedding(input_ids)

     def forward(
@@ -842,7 +842,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
             }
         )

-        # Assumes that self.forward is called after self.get_input_embeddings
+        # Assumes that self.forward is called after self.embed_input_ids
         self._is_text_input = True

     def get_text_features(
@@ -903,7 +903,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
     def get_language_model(self) -> torch.nn.Module:
         return self.text_model

-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: MultiModalEmbeddings | None = None,
@@ -917,16 +917,16 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
         # This is to satisfy the type checker for each overload
         if multimodal_embeddings is None or is_multimodal is None:
-            return super().get_input_embeddings(input_ids)
+            return super().embed_input_ids(input_ids)

-        return super().get_input_embeddings(
+        return super().embed_input_ids(
             input_ids,
             multimodal_embeddings=multimodal_embeddings,
             is_multimodal=is_multimodal,
             handle_oov_mm_token=handle_oov_mm_token,
         )

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
......
@@ -439,7 +439,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
......
@@ -311,7 +311,7 @@ class CohereModel(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

     def forward(
@@ -325,7 +325,7 @@ class CohereModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -436,8 +436,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
             self.model.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)

     @torch.no_grad()
     def forward(
......
@@ -354,7 +354,7 @@ class DbrxModel(nn.Module):
             ["hidden_states"], config.d_model
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)

     def forward(
@@ -368,7 +368,7 @@ class DbrxModel(nn.Module):
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
-                hidden_states = self.get_input_embeddings(input_ids)
+                hidden_states = self.embed_input_ids(input_ids)
         else:
             assert intermediate_tensors
             hidden_states = intermediate_tensors["hidden_states"]
@@ -455,8 +455,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
             self.transformer.make_empty_intermediate_tensors
         )

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.embed_input_ids(input_ids)

     def forward(
         self,
......