Unverified Commit bd2659a5 authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

Increase Flexibility for OOV Multimodal Token Handling (#34858)


Signed-off-by: default avatarAlex Brooks <albrooks@redhat.com>
parent 90512b2e
......@@ -2301,13 +2301,11 @@ class Qwen3VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
......@@ -1184,13 +1184,11 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_size
......@@ -1219,7 +1217,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0
......@@ -1232,7 +1229,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
......
......@@ -877,7 +877,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
......@@ -890,7 +889,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -937,7 +937,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
......@@ -945,6 +944,19 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
# NOTE: This behavior is consistent with the previous OOV handling,
# but does not currently handle the start/stop toks around the
# image features (<patch_start> <patch_end> <im_start> <im_end>)
# See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
#
# If this becomes an issue or we refactor to handle this using the
# processor info in the future, it would probably be best to handle
# those too.
self.configure_mm_token_handling(
self.config.text_config.vocab_size,
[self.config.image_token_id],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer(
config.vision_config,
......@@ -1080,8 +1092,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
......@@ -1091,7 +1101,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -265,7 +265,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in
......
......@@ -551,6 +551,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self.multi_modal_config = multimodal_config
assert self.multi_modal_config
self.configure_mm_token_handling(
self.config.vocab_size,
[self.config.audio_token_index],
)
self.secondary_weights = []
if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
......@@ -707,8 +712,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
......@@ -718,7 +721,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
......
......@@ -298,7 +298,6 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
*,
is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input.
......
......@@ -996,7 +996,6 @@ class WhisperForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment