Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
......@@ -56,13 +56,13 @@ The initialization code should look like this:
### Computation Code
- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
- Add a `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
```python
class MyModel(nn.Module):
...
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
...
```
......
......@@ -36,7 +36,7 @@ Further update the model as follows:
More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
??? code
......@@ -49,7 +49,7 @@ Further update the model as follows:
image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(
def embed_multimodal(
self,
**kwargs: object,
) -> MultiModalEmbeddings | None:
......@@ -69,7 +69,7 @@ Further update the model as follows:
!!! note
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
You may override this method if additional logic is required for your model when merging embeddings.
......
......@@ -382,7 +382,7 @@ class ApertusModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -396,7 +396,7 @@ class ApertusModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -557,8 +557,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -239,7 +239,7 @@ class ArceeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -254,7 +254,7 @@ class ArceeModel(nn.Module):
hidden_states = (
inputs_embeds
if inputs_embeds is not None
else self.get_input_embeddings(input_ids)
else self.embed_input_ids(input_ids)
)
residual = None
else:
......@@ -423,8 +423,8 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights into the model (delegates to inner model and handles
......
......@@ -442,7 +442,7 @@ class ArcticModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -456,7 +456,7 @@ class ArcticModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -496,8 +496,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -613,7 +613,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......@@ -629,8 +629,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
multimodal_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.embed_input_ids(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
......
......@@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -309,7 +309,7 @@ class BaiChuanModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -323,7 +323,7 @@ class BaiChuanModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -426,8 +426,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -438,7 +438,7 @@ class BailingMoeModel(nn.Module):
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
def forward(
......@@ -452,7 +452,7 @@ class BailingMoeModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -608,8 +608,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -314,7 +314,7 @@ class BambaModel(nn.Module):
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -328,7 +328,7 @@ class BambaModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -493,8 +493,8 @@ class BambaForCausalLM(
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -375,7 +375,7 @@ class BertModel(nn.Module, SupportsQuant):
self.embeddings = embedding_class(self.config)
self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.word_embeddings(input_ids)
def forward(
......@@ -486,8 +486,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
)
self.pooler = self._build_pooler(pooler_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......@@ -835,8 +835,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
}
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......@@ -893,8 +893,8 @@ class BertForTokenClassification(nn.Module):
}
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......
......@@ -463,7 +463,7 @@ class BertWithRope(nn.Module, SupportsQuant):
)
self.pooler = BertPooler(self.config) if add_pooling_layer else None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
def forward(
......@@ -714,8 +714,8 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
loaded_params = loader.load_weights(weights)
return loaded_params
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.new.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.new.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -630,7 +630,7 @@ class Blip2ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -271,7 +271,7 @@ class BloomModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
def forward(
......@@ -285,7 +285,7 @@ class BloomModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
else:
assert intermediate_tensors is not None
......@@ -353,8 +353,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
......@@ -912,7 +912,7 @@ class ChameleonModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -998,7 +998,7 @@ class ChameleonForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......@@ -1006,7 +1006,7 @@ class ChameleonForConditionalGeneration(
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.dtype)
)
vision_embeddings = self.model.get_input_embeddings(image_tokens)
vision_embeddings = self.model.embed_input_ids(image_tokens)
return vision_embeddings
def forward(
......
......@@ -353,7 +353,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
self.encoder.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
def forward(
......@@ -368,7 +368,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
......@@ -451,8 +451,8 @@ class ChatGLMBaseModel(nn.Module):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def compute_logits(
self,
......
......@@ -561,7 +561,7 @@ class CLIPTextTransformer(nn.Module):
eps=config.layer_norm_eps,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.token_embedding(input_ids)
def forward(
......@@ -842,7 +842,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
}
)
# Assumes that self.forward is called after self.get_input_embeddings
# Assumes that self.forward is called after self.embed_input_ids
self._is_text_input = True
def get_text_features(
......@@ -903,7 +903,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -917,16 +917,16 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -439,7 +439,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -311,7 +311,7 @@ class CohereModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -325,7 +325,7 @@ class CohereModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -436,8 +436,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@torch.no_grad()
def forward(
......
......@@ -354,7 +354,7 @@ class DbrxModel(nn.Module):
["hidden_states"], config.d_model
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
......@@ -368,7 +368,7 @@ class DbrxModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
......@@ -455,8 +455,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment