Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -685,7 +685,7 @@ class Gemma3nSelfDecoder(nn.Module): ...@@ -685,7 +685,7 @@ class Gemma3nSelfDecoder(nn.Module):
per_layer_inputs = per_layer_projection per_layer_inputs = per_layer_projection
return per_layer_inputs return per_layer_inputs
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.embed_scale return self.embed_tokens(input_ids) * self.embed_scale
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
...@@ -712,7 +712,7 @@ class Gemma3nSelfDecoder(nn.Module): ...@@ -712,7 +712,7 @@ class Gemma3nSelfDecoder(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states_0 = inputs_embeds hidden_states_0 = inputs_embeds
else: else:
hidden_states_0 = self.get_input_embeddings(input_ids) hidden_states_0 = self.embed_input_ids(input_ids)
adjusted_per_layer_inputs = self.get_per_layer_inputs( adjusted_per_layer_inputs = self.get_per_layer_inputs(
hidden_states_0, per_layer_inputs hidden_states_0, per_layer_inputs
...@@ -881,8 +881,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -881,8 +881,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_per_layer_input_embeddings(input_ids) return self.self_decoder.get_per_layer_input_embeddings(input_ids)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.self_decoder.get_input_embeddings(input_ids) return self.self_decoder.embed_input_ids(input_ids)
def fast_prefill_forward( def fast_prefill_forward(
self, self,
...@@ -1125,8 +1125,8 @@ class Gemma3nForCausalLM(nn.Module): ...@@ -1125,8 +1125,8 @@ class Gemma3nForCausalLM(nn.Module):
config.vocab_size, soft_cap=config.final_logit_softcapping config.vocab_size, soft_cap=config.final_logit_softcapping
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -645,7 +645,7 @@ class Gemma3nForConditionalGeneration( ...@@ -645,7 +645,7 @@ class Gemma3nForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if mm_input_by_modality is None: if mm_input_by_modality is None:
return [] return []
...@@ -664,7 +664,7 @@ class Gemma3nForConditionalGeneration( ...@@ -664,7 +664,7 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings.extend(audio_embeddings) multimodal_embeddings.extend(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
...@@ -689,9 +689,9 @@ class Gemma3nForConditionalGeneration( ...@@ -689,9 +689,9 @@ class Gemma3nForConditionalGeneration(
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
...@@ -709,10 +709,10 @@ class Gemma3nForConditionalGeneration( ...@@ -709,10 +709,10 @@ class Gemma3nForConditionalGeneration(
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
# NOTE (NickLucche) During profiling, `get_input_embeddings` is not # NOTE (NickLucche) During profiling, `embed_input_ids` is not
# called, hence we don't have input_ids to compute PLEs. We simply # called, hence we don't have input_ids to compute PLEs. We simply
# select a chunk of pre-allocated PLEs. During normal execution, # select a chunk of pre-allocated PLEs. During normal execution,
# `get_input_embeddings` is called before forward, hence this slice # `embed_input_ids` is called before forward, hence this slice
# will contain PLEs computed from the actual input_ids. # will contain PLEs computed from the actual input_ids.
per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]] per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
......
...@@ -275,8 +275,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -275,8 +275,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1594,9 +1594,7 @@ class Glm4vForConditionalGeneration( ...@@ -1594,9 +1594,7 @@ class Glm4vForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return None return None
......
...@@ -455,7 +455,7 @@ class Glm4MoeModel(nn.Module): ...@@ -455,7 +455,7 @@ class Glm4MoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -469,7 +469,7 @@ class Glm4MoeModel(nn.Module): ...@@ -469,7 +469,7 @@ class Glm4MoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -704,8 +704,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper ...@@ -704,8 +704,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper
self.extract_moe_parameters(example_moe) self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -149,7 +149,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module): ...@@ -149,7 +149,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -211,8 +211,8 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): ...@@ -211,8 +211,8 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe) self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -756,9 +756,9 @@ class GLM4VForCausalLM( ...@@ -756,9 +756,9 @@ class GLM4VForCausalLM(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.transformer return self.transformer
get_input_embeddings = SupportsMultiModal.get_input_embeddings embed_input_ids = SupportsMultiModal.embed_input_ids
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -213,7 +213,7 @@ class GPT2Model(nn.Module): ...@@ -213,7 +213,7 @@ class GPT2Model(nn.Module):
["hidden_states"], config.n_embd ["hidden_states"], config.n_embd
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -225,7 +225,7 @@ class GPT2Model(nn.Module): ...@@ -225,7 +225,7 @@ class GPT2Model(nn.Module):
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
else: else:
...@@ -293,8 +293,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -293,8 +293,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
...@@ -365,8 +365,8 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -365,8 +365,8 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
} }
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -230,7 +230,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -230,7 +230,7 @@ class GPTBigCodeModel(nn.Module):
["hidden_states"], config.n_embd ["hidden_states"], config.n_embd
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -242,7 +242,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -242,7 +242,7 @@ class GPTBigCodeModel(nn.Module):
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
hidden_states = inputs_embeds + self.wpe(position_ids) hidden_states = inputs_embeds + self.wpe(position_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -306,8 +306,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -306,8 +306,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -215,7 +215,7 @@ class GPTJModel(nn.Module): ...@@ -215,7 +215,7 @@ class GPTJModel(nn.Module):
["hidden_states"], config.n_embd ["hidden_states"], config.n_embd
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -229,7 +229,7 @@ class GPTJModel(nn.Module): ...@@ -229,7 +229,7 @@ class GPTJModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in islice(self.h, self.start_layer, self.end_layer): for layer in islice(self.h, self.start_layer, self.end_layer):
...@@ -319,8 +319,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -319,8 +319,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -229,7 +229,7 @@ class GPTNeoXModel(nn.Module): ...@@ -229,7 +229,7 @@ class GPTNeoXModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids) return self.embed_in(input_ids)
def forward( def forward(
...@@ -243,7 +243,7 @@ class GPTNeoXModel(nn.Module): ...@@ -243,7 +243,7 @@ class GPTNeoXModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in islice(self.layers, self.start_layer, self.end_layer): for layer in islice(self.layers, self.start_layer, self.end_layer):
...@@ -317,8 +317,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -317,8 +317,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self.gpt_neox.make_empty_intermediate_tensors self.gpt_neox.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids) return self.gpt_neox.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -269,7 +269,7 @@ class GptOssModel(nn.Module): ...@@ -269,7 +269,7 @@ class GptOssModel(nn.Module):
) )
self.aux_hidden_state_layers = tuple[int, ...]() self.aux_hidden_state_layers = tuple[int, ...]()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids) return self.embedding(input_ids)
def forward( def forward(
...@@ -283,7 +283,7 @@ class GptOssModel(nn.Module): ...@@ -283,7 +283,7 @@ class GptOssModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
x = inputs_embeds x = inputs_embeds
else: else:
x = self.get_input_embeddings(input_ids) x = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
...@@ -703,8 +703,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): ...@@ -703,8 +703,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
num_layers = len(self.model.layers) num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -318,7 +318,7 @@ class GraniteModel(nn.Module): ...@@ -318,7 +318,7 @@ class GraniteModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -332,7 +332,7 @@ class GraniteModel(nn.Module): ...@@ -332,7 +332,7 @@ class GraniteModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.config.embedding_multiplier hidden_states *= self.config.embedding_multiplier
else: else:
...@@ -473,8 +473,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -473,8 +473,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -767,7 +767,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -767,7 +767,7 @@ class GraniteSpeechForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(
self, self,
**kwargs: object, **kwargs: object,
) -> MultiModalEmbeddings: ) -> MultiModalEmbeddings:
...@@ -779,7 +779,7 @@ class GraniteSpeechForConditionalGeneration( ...@@ -779,7 +779,7 @@ class GraniteSpeechForConditionalGeneration(
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
return audio_features return audio_features
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -790,9 +790,9 @@ class GraniteSpeechForConditionalGeneration( ...@@ -790,9 +790,9 @@ class GraniteSpeechForConditionalGeneration(
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -315,7 +315,7 @@ class GraniteMoeModel(nn.Module): ...@@ -315,7 +315,7 @@ class GraniteMoeModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -329,7 +329,7 @@ class GraniteMoeModel(nn.Module): ...@@ -329,7 +329,7 @@ class GraniteMoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.embedding_multiplier hidden_states *= self.embedding_multiplier
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -531,8 +531,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -531,8 +531,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 / self.config.logits_scaling, scale=1 / self.config.logits_scaling,
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -366,7 +366,7 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -366,7 +366,7 @@ class GraniteMoeHybridModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -380,7 +380,7 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -380,7 +380,7 @@ class GraniteMoeHybridModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states = hidden_states * self.embedding_multiplier hidden_states = hidden_states * self.embedding_multiplier
residual = None residual = None
else: else:
...@@ -680,8 +680,8 @@ class GraniteMoeHybridForCausalLM( ...@@ -680,8 +680,8 @@ class GraniteMoeHybridForCausalLM(
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -182,7 +182,7 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -182,7 +182,7 @@ class GraniteMoeSharedModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -196,7 +196,7 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -196,7 +196,7 @@ class GraniteMoeSharedModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.embedding_multiplier hidden_states *= self.embedding_multiplier
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -295,8 +295,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -295,8 +295,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
scale=1 / self.config.logits_scaling, scale=1 / self.config.logits_scaling,
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -334,7 +334,7 @@ class Grok1Model(nn.Module): ...@@ -334,7 +334,7 @@ class Grok1Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
hidden_states = hidden_states * self.embedding_multiplier_scale hidden_states = hidden_states * self.embedding_multiplier_scale
return hidden_states return hidden_states
...@@ -350,7 +350,7 @@ class Grok1Model(nn.Module): ...@@ -350,7 +350,7 @@ class Grok1Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -522,8 +522,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -522,8 +522,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -643,7 +643,7 @@ class HunYuanModel(nn.Module): ...@@ -643,7 +643,7 @@ class HunYuanModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -657,7 +657,7 @@ class HunYuanModel(nn.Module): ...@@ -657,7 +657,7 @@ class HunYuanModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -987,8 +987,8 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): ...@@ -987,8 +987,8 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
) )
return loader.load_weights(weights) return loader.load_weights(weights)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
......
...@@ -732,7 +732,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -732,7 +732,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(
self, self,
**kwargs: object, **kwargs: object,
) -> MultiModalEmbeddings: ) -> MultiModalEmbeddings:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment