"docs/vscode:/vscode.git/clone" did not exist on "fba06427043be95a6c3d4329aeba8b519c449f23"
Unverified Commit bd2659a5 authored by Alex Brooks, committed by GitHub
Browse files

Increase Flexibility for OOV Multimodal Token Handling (#34858)


Signed-off-by: Alex Brooks <albrooks@redhat.com>
parent 90512b2e
......@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
......@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0
......@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
......
......@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""Embed input IDs with optional multimodal embeddings."""
if multimodal_embeddings is None or is_multimodal is None:
......@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
......@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
......
......@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
......@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# Early return for text-only inference (no multimodal data)
if multimodal_embeddings is None or is_multimodal is None:
......@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
......@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
self.quant_config = quant_config
self.cache_config = cache_config
# Check for OOV tokens to see if offsets need to be preserved
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.audio_token_index],
)
with self._mark_language_model(vllm_config):
# The language model is typically a Granite LLM
self.language_model = init_vllm_registered_model(
......@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
......@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
Set internally by `_mark_tower_model`.
"""
_has_oov_mm_tokens: bool = False
"""
In general, models should set this at init time by invoking
`configure_mm_token_handling` and passing all potentially
OOV multimodal token IDs.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""
......@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
"""
...
def configure_mm_token_handling(self, vocab_size: int, mm_token_ids: list[int]):
    """Record whether any multimodal token ID is out of vocabulary.

    A token ID is treated as out of vocabulary when it is greater than or
    equal to `vocab_size`. The outcome is stored on
    `self._has_oov_mm_tokens` and reported at INFO level. When the flag is
    set, multimodal tokens are explicitly masked out while computing text
    embeddings, since the multimodal embeddings are scattered over those
    positions afterwards.

    Args:
        vocab_size: Size of the text vocabulary to compare against.
        mm_token_ids: Multimodal token IDs that may fall outside the
            vocabulary.
    """
    found_oov = False
    for token_id in mm_token_ids:
        if token_id >= vocab_size:
            # One out-of-vocabulary ID is enough; stop scanning.
            found_oov = True
            break
    self._has_oov_mm_tokens = found_oov

    logger.info(
        "Contains out of vocabulary multimodal tokens? %s",
        self._has_oov_mm_tokens,
    )
def get_language_model(self) -> VllmModel:
"""
Returns the underlying language model used for text generation.
......@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings,
*,
is_multimodal: torch.Tensor,
handle_oov_mm_token: bool = False,
) -> Tensor: ...
def _embed_text_input_ids(
......@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
embed_input_ids: Callable[[Tensor], Tensor],
*,
is_multimodal: Tensor | None,
handle_oov_mm_token: bool,
) -> Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = embed_input_ids(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
if is_multimodal is not None and self._has_oov_mm_tokens:
# Force all input IDs to be in vocab. We do this instead of dropping the
# multimodal positions before embedding so that any external configuration
# requiring offset tracking, e.g., LoRA, is applied correctly regardless
# of whether or not we have multimodal tokens.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
return embed_input_ids(in_vocab_ids)
return embed_input_ids(input_ids)
......@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> Tensor:
"""
Apply token embeddings to `input_ids`.
......@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `embed_input_ids` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
NOTE: If this model has multimodal tokens that are out of vocabulary
(i.e., self._has_oov_mm_tokens=True), the input_ids will be copied
and masked to 0 during the forward pass for the text embeddings.
"""
from .utils import _merge_multimodal_embeddings
# Get text embeddings first; the multimodal embeddings will then
# overwrite whatever was computed at the multimodal positions, in
# both the in-vocabulary and out-of-vocabulary cases.
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
......@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
......@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (
......
......@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.config = config
self.multimodal_config = multimodal_config
self.configure_mm_token_handling(
vocab_size=config.text_config.vocab_size,
mm_token_ids=[config.image_token_index],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
......@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
......@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
......@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation.
......@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return merge_interleaved_embeddings(
inputs_embeds,
......@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment