Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
......@@ -685,7 +685,7 @@ class Gemma3nSelfDecoder(nn.Module):
per_layer_inputs = per_layer_projection
return per_layer_inputs
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
......@@ -712,7 +712,7 @@ class Gemma3nSelfDecoder(nn.Module):
if inputs_embeds is not None:
hidden_states_0 = inputs_embeds
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
hidden_states_0 = self.embed_input_ids(input_ids)
adjusted_per_layer_inputs = self.get_per_layer_inputs(
hidden_states_0, per_layer_inputs
......@@ -881,8 +881,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.embed_input_ids(input_ids)
def fast_prefill_forward(
self,
......@@ -1125,8 +1125,8 @@ class Gemma3nForCausalLM(nn.Module):
config.vocab_size, soft_cap=config.final_logit_softcapping
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -645,7 +645,7 @@ class Gemma3nForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if mm_input_by_modality is None:
return []
......@@ -664,7 +664,7 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings.extend(audio_embeddings)
return multimodal_embeddings
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
......@@ -689,9 +689,9 @@ class Gemma3nForConditionalGeneration(
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......@@ -709,10 +709,10 @@ class Gemma3nForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE (NickLucche) During profiling, `get_input_embeddings` is not
# NOTE (NickLucche) During profiling, `embed_input_ids` is not
# called, hence we don't have input_ids to compute PLEs. We simply
# select a chunk of pre-allocated PLEs. During normal execution,
# `get_input_embeddings` is called before forward, hence this slice
# `embed_input_ids` is called before forward, hence this slice
# will contain PLEs computed from the actual input_ids.
per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
......
......@@ -275,8 +275,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -1594,9 +1594,7 @@ class Glm4vForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return None
......
......@@ -455,7 +455,7 @@ class Glm4MoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -469,7 +469,7 @@ class Glm4MoeModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -704,8 +704,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -149,7 +149,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -211,8 +211,8 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -756,9 +756,9 @@ class GLM4VForCausalLM(
def get_language_model(self) -> torch.nn.Module:
return self.transformer
get_input_embeddings = SupportsMultiModal.get_input_embeddings
embed_input_ids = SupportsMultiModal.embed_input_ids
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
......
......@@ -213,7 +213,7 @@ class GPT2Model(nn.Module):
["hidden_states"], config.n_embd
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
......@@ -225,7 +225,7 @@ class GPT2Model(nn.Module):
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
......@@ -293,8 +293,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......@@ -365,8 +365,8 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
}
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
......
......@@ -230,7 +230,7 @@ class GPTBigCodeModel(nn.Module):
["hidden_states"], config.n_embd
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
......@@ -242,7 +242,7 @@ class GPTBigCodeModel(nn.Module):
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = inputs_embeds + self.wpe(position_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
......@@ -306,8 +306,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -215,7 +215,7 @@ class GPTJModel(nn.Module):
["hidden_states"], config.n_embd
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
def forward(
......@@ -229,7 +229,7 @@ class GPTJModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in islice(self.h, self.start_layer, self.end_layer):
......@@ -319,8 +319,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -229,7 +229,7 @@ class GPTNeoXModel(nn.Module):
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids)
def forward(
......@@ -243,7 +243,7 @@ class GPTNeoXModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
......@@ -317,8 +317,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self.gpt_neox.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -269,7 +269,7 @@ class GptOssModel(nn.Module):
)
self.aux_hidden_state_layers = tuple[int, ...]()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
def forward(
......@@ -283,7 +283,7 @@ class GptOssModel(nn.Module):
if inputs_embeds is not None:
x = inputs_embeds
else:
x = self.get_input_embeddings(input_ids)
x = self.embed_input_ids(input_ids)
residual = None
else:
......@@ -703,8 +703,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -318,7 +318,7 @@ class GraniteModel(nn.Module):
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -332,7 +332,7 @@ class GraniteModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.config.embedding_multiplier
else:
......@@ -473,8 +473,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -767,7 +767,7 @@ class GraniteSpeechForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
def embed_multimodal(
self,
**kwargs: object,
) -> MultiModalEmbeddings:
......@@ -779,7 +779,7 @@ class GraniteSpeechForConditionalGeneration(
audio_features = self._process_audio_input(audio_input)
return audio_features
def get_input_embeddings(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
......@@ -790,9 +790,9 @@ class GraniteSpeechForConditionalGeneration(
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)
return super().embed_input_ids(input_ids)
return super().get_input_embeddings(
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
......
......@@ -315,7 +315,7 @@ class GraniteMoeModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -329,7 +329,7 @@ class GraniteMoeModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.embedding_multiplier
else:
assert intermediate_tensors is not None
......@@ -531,8 +531,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 / self.config.logits_scaling,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -366,7 +366,7 @@ class GraniteMoeHybridModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -380,7 +380,7 @@ class GraniteMoeHybridModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
hidden_states = hidden_states * self.embedding_multiplier
residual = None
else:
......@@ -680,8 +680,8 @@ class GraniteMoeHybridForCausalLM(
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -182,7 +182,7 @@ class GraniteMoeSharedModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -196,7 +196,7 @@ class GraniteMoeSharedModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.embedding_multiplier
else:
assert intermediate_tensors is not None
......@@ -295,8 +295,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 / self.config.logits_scaling,
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -334,7 +334,7 @@ class Grok1Model(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = hidden_states * self.embedding_multiplier_scale
return hidden_states
......@@ -350,7 +350,7 @@ class Grok1Model(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -522,8 +522,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
......
......@@ -643,7 +643,7 @@ class HunYuanModel(nn.Module):
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
......@@ -657,7 +657,7 @@ class HunYuanModel(nn.Module):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
......@@ -987,8 +987,8 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
)
return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
......
......@@ -732,7 +732,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
def embed_multimodal(
self,
**kwargs: object,
) -> MultiModalEmbeddings:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment