Unverified commit 97d1c993, authored by Harry Mellor, committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -73,7 +73,7 @@ class DeepseekV2Model(nn.Module): ...@@ -73,7 +73,7 @@ class DeepseekV2Model(nn.Module):
self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -222,8 +222,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -222,8 +222,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.num_moe_layers = self.config.num_hidden_layers self.num_moe_layers = self.config.num_hidden_layers
self.set_moe_parameters() self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -142,7 +142,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -142,7 +142,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -206,8 +206,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): ...@@ -206,8 +206,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe) self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -557,9 +557,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -557,9 +557,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -1236,7 +1236,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1236,7 +1236,7 @@ class DeepseekV2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -1250,7 +1250,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1250,7 +1250,7 @@ class DeepseekV2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -1389,8 +1389,8 @@ class DeepseekV2ForCausalLM( ...@@ -1389,8 +1389,8 @@ class DeepseekV2ForCausalLM(
self.extract_moe_parameters(example_moe) self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -619,7 +619,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -619,7 +619,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -398,7 +398,7 @@ class Dots1Model(nn.Module): ...@@ -398,7 +398,7 @@ class Dots1Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -412,7 +412,7 @@ class Dots1Model(nn.Module): ...@@ -412,7 +412,7 @@ class Dots1Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -541,8 +541,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -541,8 +541,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -840,7 +840,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -840,7 +840,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
...@@ -858,8 +858,8 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -858,8 +858,8 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.embed_input_ids(
input_ids, input_ids,
vision_embeddings, vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id, is_multimodal=input_ids == self.config.image_token_id,
......
...@@ -465,7 +465,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -465,7 +465,7 @@ class Ernie4_5_MoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -479,7 +479,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -479,7 +479,7 @@ class Ernie4_5_MoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -726,8 +726,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -726,8 +726,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1656,9 +1656,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1656,9 +1656,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
return modalities return modalities
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return None return None
...@@ -1681,7 +1679,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1681,7 +1679,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1694,9 +1692,9 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1694,9 +1692,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
......
...@@ -561,7 +561,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -561,7 +561,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -577,7 +577,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -577,7 +577,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -642,8 +642,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -642,8 +642,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -112,7 +112,7 @@ class ErnieMultiTokenPredictor(nn.Module): ...@@ -112,7 +112,7 @@ class ErnieMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -160,8 +160,8 @@ class ErnieMTP(nn.Module, SupportsPP): ...@@ -160,8 +160,8 @@ class ErnieMTP(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -357,7 +357,7 @@ class ExaoneModel(nn.Module): ...@@ -357,7 +357,7 @@ class ExaoneModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -371,7 +371,7 @@ class ExaoneModel(nn.Module): ...@@ -371,7 +371,7 @@ class ExaoneModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -512,8 +512,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -512,8 +512,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -344,7 +344,7 @@ class Exaone4Model(nn.Module): ...@@ -344,7 +344,7 @@ class Exaone4Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -358,7 +358,7 @@ class Exaone4Model(nn.Module): ...@@ -358,7 +358,7 @@ class Exaone4Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -498,8 +498,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -498,8 +498,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -399,7 +399,7 @@ class FalconModel(nn.Module): ...@@ -399,7 +399,7 @@ class FalconModel(nn.Module):
["hidden_states"], config.hidden_size ["hidden_states"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids) return self.word_embeddings(input_ids)
def forward( def forward(
...@@ -413,7 +413,7 @@ class FalconModel(nn.Module): ...@@ -413,7 +413,7 @@ class FalconModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for layer in islice(self.h, self.start_layer, self.end_layer): for layer in islice(self.h, self.start_layer, self.end_layer):
...@@ -515,8 +515,8 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -515,8 +515,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -461,7 +461,7 @@ class FalconH1Model(nn.Module): ...@@ -461,7 +461,7 @@ class FalconH1Model(nn.Module):
else: else:
self.final_layernorm = PPMissingLayer() self.final_layernorm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -476,7 +476,7 @@ class FalconH1Model(nn.Module): ...@@ -476,7 +476,7 @@ class FalconH1Model(nn.Module):
hidden_states = inputs_embeds * self.embedding_multiplier hidden_states = inputs_embeds * self.embedding_multiplier
else: else:
hidden_states = ( hidden_states = (
self.get_input_embeddings(input_ids) * self.embedding_multiplier self.embed_input_ids(input_ids) * self.embedding_multiplier
) )
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -601,8 +601,8 @@ class FalconH1ForCausalLM( ...@@ -601,8 +601,8 @@ class FalconH1ForCausalLM(
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -333,7 +333,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -333,7 +333,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -293,7 +293,7 @@ class GemmaModel(nn.Module): ...@@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -307,7 +307,7 @@ class GemmaModel(nn.Module): ...@@ -307,7 +307,7 @@ class GemmaModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.normalizer hidden_states *= self.normalizer
residual = None residual = None
else: else:
...@@ -396,8 +396,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -396,8 +396,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module): ...@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -304,7 +304,7 @@ class Gemma2Model(nn.Module): ...@@ -304,7 +304,7 @@ class Gemma2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
hidden_states *= self.normalizer hidden_states *= self.normalizer
residual = None residual = None
else: else:
...@@ -409,8 +409,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -409,8 +409,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -393,7 +393,7 @@ class Gemma3Model(nn.Module): ...@@ -393,7 +393,7 @@ class Gemma3Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
# NOTE(woosuk): Only apply the normalizer to the output of # NOTE(woosuk): Only apply the normalizer to the output of
# vocab embedding. Don't apply it to the vision embedding. # vocab embedding. Don't apply it to the vision embedding.
return self.embed_tokens(input_ids) * self.normalizer return self.embed_tokens(input_ids) * self.normalizer
...@@ -410,7 +410,7 @@ class Gemma3Model(nn.Module): ...@@ -410,7 +410,7 @@ class Gemma3Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -540,8 +540,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -540,8 +540,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration( ...@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment