Unverified Commit bd2659a5 authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

Increase Flexibility for OOV Multimodal Token Handling (#34858)


Signed-off-by: Alex Brooks <albrooks@redhat.com>
parent 90512b2e
...@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor], embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*, *,
is_multimodal: torch.Tensor | None, is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids( inputs_embeds = super()._embed_text_input_ids(
input_ids, input_ids,
embed_input_ids, embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
# NOTE: inputs_embeds in model runner has size text_config.projection_dim # NOTE: inputs_embeds in model runner has size text_config.projection_dim
...@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
self._is_text_input = ( self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0 multimodal_embeddings is None or len(multimodal_embeddings) == 0
...@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
......
...@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""Embed input IDs with optional multimodal embeddings.""" """Embed input IDs with optional multimodal embeddings."""
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
...@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration( ...@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.decoder.embed_input_ids(input_ids) inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
......
...@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration( ...@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel( self.vision_tower = SiglipVisionModel(
config.vision_config, config.vision_config,
...@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration( ...@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# Early return for text-only inference (no multimodal data) # Early return for text-only inference (no multimodal data)
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration( ...@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration( ...@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds. # them here, as the model forward has only access to the input_embeds.
...@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration( ...@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration( ...@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
self.quant_config = quant_config self.quant_config = quant_config
self.cache_config = cache_config self.cache_config = cache_config
# Check for OOV tokens to see if offsets need to be preserved
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.audio_token_index],
)
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
# The language model is typically a Granite LLM # The language model is typically a Granite LLM
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
...@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration( ...@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration( ...@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol): ...@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
Set internally by `_mark_tower_model`. Set internally by `_mark_tower_model`.
""" """
_has_oov_mm_tokens: bool = False
"""
In general, this should be set at model init time by invoking
`configure_mm_token_handling` and passing all potentially
OOV multimodal token IDs.
"""
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
""" """
...@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol): ...@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def configure_mm_token_handling(self, vocab_size: int, mm_token_ids: list[int]):
    """Record whether any multimodal token ID lies outside the vocabulary.

    A token ID counts as out of vocabulary when it is greater than or
    equal to `vocab_size`. The outcome is stored on
    `self._has_oov_mm_tokens` and reported once through the module logger.

    Args:
        vocab_size: Size of the vocabulary the token IDs are compared to.
        mm_token_ids: Multimodal token IDs to check.
    """
    found_oov = False
    for token_id in mm_token_ids:
        if token_id >= vocab_size:
            # One out-of-range ID is enough; stop scanning.
            found_oov = True
            break

    self._has_oov_mm_tokens = found_oov
    logger.info(
        "Contains out of vocabulary multimodal tokens? %s",
        self._has_oov_mm_tokens,
    )
def get_language_model(self) -> VllmModel: def get_language_model(self) -> VllmModel:
""" """
Returns the underlying language model used for text generation. Returns the underlying language model used for text generation.
...@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol): ...@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings, multimodal_embeddings: MultiModalEmbeddings,
*, *,
is_multimodal: torch.Tensor, is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ... ) -> Tensor: ...
def _embed_text_input_ids( def _embed_text_input_ids(
...@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol): ...@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
embed_input_ids: Callable[[Tensor], Tensor], embed_input_ids: Callable[[Tensor], Tensor],
*, *,
is_multimodal: Tensor | None, is_multimodal: Tensor | None,
handle_oov_mm_token: bool,
) -> Tensor: ) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None: if is_multimodal is not None and self._has_oov_mm_tokens:
is_text = ~is_multimodal # Force all input IDs to be in vocab; we do this instead of squeezing
text_embeds = embed_input_ids(input_ids[is_text]) # to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, is applied correctly regardless of whether or not
return torch.empty( # we have multimodal tokens.
(input_ids.shape[0], text_embeds.shape[1]), in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
dtype=text_embeds.dtype, return embed_input_ids(in_vocab_ids)
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return embed_input_ids(input_ids) return embed_input_ids(input_ids)
...@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol): ...@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: Tensor | None = None, is_multimodal: Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> Tensor: ) -> Tensor:
""" """
Apply token embeddings to `input_ids`. Apply token embeddings to `input_ids`.
...@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol): ...@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
If `multimodal_embeddings` is passed, scatter them into If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`. `input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of NOTE: If this model has multimodal tokens that are out of vocabulary
the language model, you can set `handle_oov_mm_token=False` (i.e., self._has_oov_mm_tokens=True), the input_ids will be copied
to avoid calling the language model's `embed_input_ids` method and masked to 0 during the forward pass for the text embeddings.
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
""" """
from .utils import _merge_multimodal_embeddings from .utils import _merge_multimodal_embeddings
# Get text embeddings first; multimodal embeddings will clobber
# any invalid contents in the indices of multimodal embeddings
# for the in vocabulary and out of vocabulary case.
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.get_language_model().embed_input_ids, self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration( ...@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
...@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration( ...@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
...@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration( ...@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
# NOTE: These are special cases for Pixtral-12B in the HF-format # NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if ( if (
......
...@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava( self.vision_tower = init_vision_tower_for_llava(
config, config,
...@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration( ...@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.get_language_model().embed_input_ids, self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
...@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.embed_tokens, self.embed_tokens,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids) return super().embed_input_ids(input_ids)
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens # Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region. Only use the interleaved path when # in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation. # needed; otherwise fall back to the default parent implementation.
...@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids, input_ids,
self.get_language_model().embed_input_ids, self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
return merge_interleaved_embeddings( return merge_interleaved_embeddings(
inputs_embeds, inputs_embeds,
...@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) ...@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.embed_input_ids, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal): ...@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.model.embed_input_ids, self.model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration( ...@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.embed_input_ids, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.embed_input_ids, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
...@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment