Unverified commit 97d1c993, authored by Harry Mellor and committed by GitHub
Browse files

Rename clashing method names for vLLM model protocol (#27583)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 32262834
...@@ -1180,7 +1180,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1180,7 +1180,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
) )
return image_embeds return image_embeds
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
......
...@@ -482,7 +482,7 @@ class PhiMoEModel(nn.Module): ...@@ -482,7 +482,7 @@ class PhiMoEModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -496,7 +496,7 @@ class PhiMoEModel(nn.Module): ...@@ -496,7 +496,7 @@ class PhiMoEModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -648,8 +648,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -648,8 +648,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -461,7 +461,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -461,7 +461,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
...@@ -762,7 +762,7 @@ class Plamo2Model(torch.nn.Module): ...@@ -762,7 +762,7 @@ class Plamo2Model(torch.nn.Module):
self.layers = Plamo2Decoder(vllm_config=vllm_config, prefix=f"{prefix}.layers") self.layers = Plamo2Decoder(vllm_config=vllm_config, prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -776,7 +776,7 @@ class Plamo2Model(torch.nn.Module): ...@@ -776,7 +776,7 @@ class Plamo2Model(torch.nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -839,8 +839,8 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): ...@@ -839,8 +839,8 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -221,7 +221,7 @@ class QWenModel(nn.Module): ...@@ -221,7 +221,7 @@ class QWenModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids) return self.wte(input_ids)
def forward( def forward(
...@@ -235,7 +235,7 @@ class QWenModel(nn.Module): ...@@ -235,7 +235,7 @@ class QWenModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
......
...@@ -355,7 +355,7 @@ class Qwen2Model(nn.Module): ...@@ -355,7 +355,7 @@ class Qwen2Model(nn.Module):
self.aux_hidden_state_layers = tuple[int, ...]() self.aux_hidden_state_layers = tuple[int, ...]()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -369,7 +369,7 @@ class Qwen2Model(nn.Module): ...@@ -369,7 +369,7 @@ class Qwen2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -504,8 +504,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -504,8 +504,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers self.model.aux_hidden_state_layers = layers
......
...@@ -1132,7 +1132,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1132,7 +1132,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return [] return []
...@@ -1158,7 +1158,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1158,7 +1158,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
# TODO (ywang96): support overlapping modality embeddings so that # TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1. # `use_audio_in_video` will work on V1.
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1168,16 +1168,16 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1168,16 +1168,16 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids) return super().embed_input_ids(input_ids)
return super().get_input_embeddings( return super().embed_input_ids(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
def get_multimodal_embeddings_v0(self, **kwargs: object) -> NestedTensors | None: def embed_multimodal_v0(self, **kwargs: object) -> NestedTensors | None:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
......
...@@ -1534,7 +1534,7 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1534,7 +1534,7 @@ class Qwen2_5_VLForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return [] return []
......
...@@ -439,7 +439,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -439,7 +439,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return [] return []
......
...@@ -389,7 +389,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -389,7 +389,7 @@ class Qwen2MoeModel(nn.Module):
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -403,7 +403,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -403,7 +403,7 @@ class Qwen2MoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -566,8 +566,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -566,8 +566,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -73,8 +73,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -73,8 +73,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -1507,7 +1507,7 @@ class Qwen2VLForConditionalGeneration( ...@@ -1507,7 +1507,7 @@ class Qwen2VLForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return [] return []
......
...@@ -306,8 +306,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): ...@@ -306,8 +306,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
num_layers = len(self.model.layers) num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -427,7 +427,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -427,7 +427,7 @@ class Qwen3MoeModel(nn.Module):
# Track layers for auxiliary hidden state outputs (EAGLE3) # Track layers for auxiliary hidden state outputs (EAGLE3)
self.aux_hidden_state_layers: tuple[int, ...] = () self.aux_hidden_state_layers: tuple[int, ...] = ()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -441,7 +441,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -441,7 +441,7 @@ class Qwen3MoeModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -714,8 +714,8 @@ class Qwen3MoeForCausalLM( ...@@ -714,8 +714,8 @@ class Qwen3MoeForCausalLM(
num_layers = len(self.model.layers) num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -998,7 +998,7 @@ class Qwen3NextModel(nn.Module): ...@@ -998,7 +998,7 @@ class Qwen3NextModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -1012,7 +1012,7 @@ class Qwen3NextModel(nn.Module): ...@@ -1012,7 +1012,7 @@ class Qwen3NextModel(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -1217,8 +1217,8 @@ class Qwen3NextForCausalLM( ...@@ -1217,8 +1217,8 @@ class Qwen3NextForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.set_moe_parameters() self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -93,7 +93,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -93,7 +93,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
def forward( def forward(
...@@ -107,7 +107,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -107,7 +107,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1] assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states) hidden_states = self.pre_fc_norm_hidden(hidden_states)
...@@ -257,8 +257,8 @@ class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts): ...@@ -257,8 +257,8 @@ class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts):
) )
self.set_moe_parameters() self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.embed_input_ids(input_ids)
def forward( def forward(
self, self,
......
...@@ -613,7 +613,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -613,7 +613,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -1252,9 +1252,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1252,9 +1252,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return [] return []
...@@ -1278,7 +1276,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1278,7 +1276,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_embeddings += tuple(audio_embeddings) multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1286,9 +1284,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1286,9 +1284,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False, handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.get_input_embeddings, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
......
...@@ -1100,7 +1100,7 @@ class Qwen3LLMModel(Qwen3Model): ...@@ -1100,7 +1100,7 @@ class Qwen3LLMModel(Qwen3Model):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -1493,9 +1493,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1493,9 +1493,7 @@ class Qwen3VLForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
self, **kwargs: object
) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return None return None
...@@ -1557,7 +1555,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1557,7 +1555,7 @@ class Qwen3VLForConditionalGeneration(
return deepstack_input_embeds, multimodal_embeddings return deepstack_input_embeds, multimodal_embeddings
def get_input_embeddings( def embed_input_ids(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
...@@ -1565,9 +1563,9 @@ class Qwen3VLForConditionalGeneration( ...@@ -1565,9 +1563,9 @@ class Qwen3VLForConditionalGeneration(
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False, handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.get_input_embeddings, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token, handle_oov_mm_token=handle_oov_mm_token,
) )
...@@ -1577,7 +1575,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1577,7 +1575,7 @@ class Qwen3VLForConditionalGeneration(
if is_multimodal is None: if is_multimodal is None:
raise ValueError( raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, " "`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to " "please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229." "https://github.com/vllm-project/vllm/pull/16229."
) )
......
...@@ -97,7 +97,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): ...@@ -97,7 +97,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.embed_input_ids(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
......
...@@ -777,7 +777,7 @@ class QwenVLForConditionalGeneration( ...@@ -777,7 +777,7 @@ class QwenVLForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.transformer return self.transformer
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return [] return []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment