Unverified Commit 97d1c993 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -866,7 +866,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -866,7 +866,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return [] return []
......
...@@ -498,7 +498,7 @@ class FlashModel(nn.Module): ...@@ -498,7 +498,7 @@ class FlashModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -512,7 +512,7 @@ class FlashModel(nn.Module): ...@@ -512,7 +512,7 @@ class FlashModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -583,8 +583,8 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -583,8 +583,8 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -135,7 +135,7 @@ class MambaModel(nn.Module): ...@@ -135,7 +135,7 @@ class MambaModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids) return self.embeddings(input_ids)
def forward( def forward(
...@@ -149,7 +149,7 @@ class MambaModel(nn.Module): ...@@ -149,7 +149,7 @@ class MambaModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -218,8 +218,8 @@ class MambaForCausalLM( ...@@ -218,8 +218,8 @@ class MambaForCausalLM(
self.backbone.make_empty_intermediate_tensors self.backbone.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids) return self.backbone.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -131,7 +131,7 @@ class Mamba2Model(nn.Module): ...@@ -131,7 +131,7 @@ class Mamba2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids) return self.embeddings(input_ids)
def forward( def forward(
...@@ -145,7 +145,7 @@ class Mamba2Model(nn.Module): ...@@ -145,7 +145,7 @@ class Mamba2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -257,8 +257,8 @@ class Mamba2ForCausalLM( ...@@ -257,8 +257,8 @@ class Mamba2ForCausalLM(
self.backbone.make_empty_intermediate_tensors self.backbone.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids) return self.backbone.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -791,7 +791,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -791,7 +791,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.decoder return self.decoder
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
......
...@@ -70,7 +70,7 @@ class MiMoModel(Qwen2Model): ...@@ -70,7 +70,7 @@ class MiMoModel(Qwen2Model):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
......
...@@ -120,7 +120,7 @@ class MiMoMultiTokenPredictor(nn.Module): ...@@ -120,7 +120,7 @@ class MiMoMultiTokenPredictor(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -164,8 +164,8 @@ class MiMoMTP(nn.Module): ...@@ -164,8 +164,8 @@ class MiMoMTP(nn.Module):
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -440,7 +440,7 @@ class MiniCPMModel(nn.Module): ...@@ -440,7 +440,7 @@ class MiniCPMModel(nn.Module):
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
return embedding * self.config.scale_emb return embedding * self.config.scale_emb
...@@ -455,7 +455,7 @@ class MiniCPMModel(nn.Module): ...@@ -455,7 +455,7 @@ class MiniCPMModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -615,8 +615,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -615,8 +615,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers self.model.aux_hidden_state_layers = layers
......
...@@ -193,7 +193,7 @@ class EagleMiniCPMModel(nn.Module): ...@@ -193,7 +193,7 @@ class EagleMiniCPMModel(nn.Module):
] ]
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
return embedding * self.config.scale_emb return embedding * self.config.scale_emb
...@@ -203,7 +203,7 @@ class EagleMiniCPMModel(nn.Module): ...@@ -203,7 +203,7 @@ class EagleMiniCPMModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
input_embeds = self.get_input_embeddings(input_ids) input_embeds = self.embed_input_ids(input_ids)
input_embeds = self.input_norm1(input_embeds) input_embeds = self.input_norm1(input_embeds)
hidden_states = self.input_norm2(hidden_states) hidden_states = self.input_norm2(hidden_states)
...@@ -354,8 +354,8 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -354,8 +354,8 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
vllm_config=vllm_config, prefix=prefix, start_layer=start_layer vllm_config=vllm_config, prefix=prefix, start_layer=start_layer
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1139,7 +1139,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1139,7 +1139,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.llm return self.llm
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
......
...@@ -360,7 +360,7 @@ class MiniMaxM2Model(nn.Module): ...@@ -360,7 +360,7 @@ class MiniMaxM2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -374,7 +374,7 @@ class MiniMaxM2Model(nn.Module): ...@@ -374,7 +374,7 @@ class MiniMaxM2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -510,8 +510,8 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsPP): ...@@ -510,8 +510,8 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -620,7 +620,7 @@ class MiniMaxText01Model(nn.Module): ...@@ -620,7 +620,7 @@ class MiniMaxText01Model(nn.Module):
) )
minimax_cache_tensors[:, slots_tensor, ...] = 0 minimax_cache_tensors[:, slots_tensor, ...] = 0
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -709,8 +709,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -709,8 +709,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -353,7 +353,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -353,7 +353,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
...@@ -371,8 +371,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -371,8 +371,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.embed_multimodal(**kwargs)
inputs_embeds = self.get_input_embeddings( inputs_embeds = self.embed_input_ids(
input_ids, input_ids,
vision_embeddings, vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index, is_multimodal=input_ids == self.config.image_token_index,
......
...@@ -549,7 +549,7 @@ class Mistral3ForConditionalGeneration( ...@@ -549,7 +549,7 @@ class Mistral3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -345,7 +345,7 @@ class MixtralModel(nn.Module): ...@@ -345,7 +345,7 @@ class MixtralModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -359,7 +359,7 @@ class MixtralModel(nn.Module): ...@@ -359,7 +359,7 @@ class MixtralModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -591,8 +591,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -591,8 +591,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -865,7 +865,7 @@ class Llama4ForConditionalGeneration( ...@@ -865,7 +865,7 @@ class Llama4ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -46,7 +46,7 @@ class ModernBertEmbeddings(nn.Module): ...@@ -46,7 +46,7 @@ class ModernBertEmbeddings(nn.Module):
) )
self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias) self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids) return self.tok_embeddings(input_ids)
def forward( def forward(
...@@ -225,8 +225,8 @@ class ModernBertModel(nn.Module): ...@@ -225,8 +225,8 @@ class ModernBertModel(nn.Module):
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.get_input_embeddings(input_ids) return self.embeddings.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
...@@ -337,8 +337,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -337,8 +337,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
} }
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = [] self_weights = []
...@@ -424,8 +424,8 @@ class ModernBertForTokenClassification(nn.Module): ...@@ -424,8 +424,8 @@ class ModernBertForTokenClassification(nn.Module):
} }
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self, skip_prefixes=["drop"]) loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
......
...@@ -832,7 +832,7 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -832,7 +832,7 @@ class MolmoModel(nn.Module, SupportsQuant):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -1491,7 +1491,7 @@ class MolmoForCausalLM( ...@@ -1491,7 +1491,7 @@ class MolmoForCausalLM(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -248,7 +248,7 @@ class MPTModel(nn.Module): ...@@ -248,7 +248,7 @@ class MPTModel(nn.Module):
["hidden_states"], config.d_model ["hidden_states"], config.d_model
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -262,7 +262,7 @@ class MPTModel(nn.Module): ...@@ -262,7 +262,7 @@ class MPTModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -308,8 +308,8 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -308,8 +308,8 @@ class MPTForCausalLM(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors self.transformer.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -655,7 +655,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -655,7 +655,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
The replacement returned is not actually used to replace the placeholder The replacement returned is not actually used to replace the placeholder
tokens - it's just used to make sure we allocate the correct number tokens - it's just used to make sure we allocate the correct number
of tokens. of tokens.
Actual replacement is done in get_multimodal_embeddings of Actual replacement is done in embed_multimodal of
NemotronH_Nano_VL_V2 NemotronH_Nano_VL_V2
(specifically in _process_video_input -> _create_final_video_embeddings). (specifically in _process_video_input -> _create_final_video_embeddings).
There, we create the final embeddings with text embeddings for indicator tokens There, we create the final embeddings with text embeddings for indicator tokens
...@@ -1401,7 +1401,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1401,7 +1401,7 @@ class NemotronH_Nano_VL_V2(
# Create final video embeddings, merging text embeddings for indicator # Create final video embeddings, merging text embeddings for indicator
# tokens with video embeddings # tokens with video embeddings
text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids) text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)
final_video_embeddings = _merge_multimodal_embeddings( final_video_embeddings = _merge_multimodal_embeddings(
inputs_embeds=text_embeddings, inputs_embeds=text_embeddings,
multimodal_embeddings=video_embeddings, multimodal_embeddings=video_embeddings,
...@@ -1465,7 +1465,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1465,7 +1465,7 @@ class NemotronH_Nano_VL_V2(
return modalities return modalities
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Validate the multimodal input keyword arguments # Validate the multimodal input keyword arguments
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities is None: if modalities is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment