Unverified Commit bd2659a5 authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

Increase Flexibility for OOV Multimodal Token Handling (#34858)


Signed-off-by: default avatarAlex Brooks <albrooks@redhat.com>
parent 90512b2e
...@@ -2301,13 +2301,11 @@ class Qwen3VLForConditionalGeneration( ...@@ -2301,13 +2301,11 @@ class Qwen3VLForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids( inputs_embeds = self._embed_text_input_ids(
input_ids, input_ids,
self.language_model.embed_input_ids, self.language_model.embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
......
...@@ -1184,13 +1184,11 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1184,13 +1184,11 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids: Callable[[torch.Tensor], torch.Tensor], embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*, *,
is_multimodal: torch.Tensor | None, is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids( inputs_embeds = super()._embed_text_input_ids(
input_ids, input_ids,
embed_input_ids, embed_input_ids,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
# NOTE: inputs_embeds in model runner has size text_config.projection_size # NOTE: inputs_embeds in model runner has size text_config.projection_size
...@@ -1219,7 +1217,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1219,7 +1217,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
self._is_text_input = ( self._is_text_input = (
multimodal_embeddings is None or len(multimodal_embeddings) == 0 multimodal_embeddings is None or len(multimodal_embeddings) == 0
...@@ -1232,7 +1229,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): ...@@ -1232,7 +1229,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
......
...@@ -877,7 +877,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -877,7 +877,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
...@@ -890,7 +889,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -890,7 +889,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -937,7 +937,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -937,7 +937,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -945,6 +944,19 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -945,6 +944,19 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
# NOTE: This behavior is consistent with the previous OOV handling,
# but does not currently handle the start/stop toks around the
# image features (<patch_start> <patch_end> <im_start> <im_end>)
# See: https://huggingface.co/stepfun-ai/step3/blob/main/processing_step3v.py#L323
#
# If this becomes an issue or we refactor to handle this using the
# processor info in the future, it would probably be best to handle
# those too.
self.configure_mm_token_handling(
self.config.text_config.vocab_size,
[self.config.image_token_id],
)
with self._mark_tower_model(vllm_config, "image"): with self._mark_tower_model(vllm_config, "image"):
self.vision_model = Step3VisionTransformer( self.vision_model = Step3VisionTransformer(
config.vision_config, config.vision_config,
...@@ -1080,8 +1092,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1080,8 +1092,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -1091,7 +1101,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1091,7 +1101,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -265,7 +265,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -265,7 +265,6 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings # We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in # to be calculated. However, due to the mandatory token ids in
......
...@@ -551,6 +551,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -551,6 +551,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self.multi_modal_config = multimodal_config self.multi_modal_config = multimodal_config
assert self.multi_modal_config assert self.multi_modal_config
self.configure_mm_token_handling(
self.config.vocab_size,
[self.config.audio_token_index],
)
self.secondary_weights = [] self.secondary_weights = []
if config.audio_model_id is not None: if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights # this prefix is not for initialization, but for loading weights
...@@ -707,8 +712,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -707,8 +712,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
# This is to satisfy the type checker for each overload # This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None: if multimodal_embeddings is None or is_multimodal is None:
...@@ -718,7 +721,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -718,7 +721,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
input_ids, input_ids,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
) )
def forward( def forward(
......
...@@ -298,7 +298,6 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -298,7 +298,6 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
# Multi-modal token ID may exceed vocab size # Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
"""Pass post-conv embeddings directly as input. """Pass post-conv embeddings directly as input.
......
...@@ -996,7 +996,6 @@ class WhisperForConditionalGeneration( ...@@ -996,7 +996,6 @@ class WhisperForConditionalGeneration(
multimodal_embeddings: MultiModalEmbeddings | None = None, multimodal_embeddings: MultiModalEmbeddings | None = None,
*, *,
is_multimodal: torch.Tensor | None = None, is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since # This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. # Whisper does not have encoder text tokens.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment