Unverified Commit 4f6eed3b authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Core] Simplify multimodal masking (#34246)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 36d7f198
......@@ -4,9 +4,11 @@
import pytest
import torch
from vllm.model_executor.models.utils import AutoWeightsLoader
pytestmark = pytest.mark.cpu_test
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
_merge_multimodal_embeddings,
)
from vllm.platforms import current_platform
class ModuleWithBatchNorm(torch.nn.Module):
......@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
return self.nested_mod(x)
@pytest.mark.cpu_test
def test_module_with_batchnorm_can_load():
"""Ensure the auto weight loader can load batchnorm stats."""
mod = ModuleWithBatchNorm()
......@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
assert new_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_with_child_containing_batchnorm_can_autoload():
"""Ensure the auto weight loader can load nested modules batchnorm stats."""
mod = ModuleWithNestedBatchNorm()
......@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_skip_prefix():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
......@@ -119,6 +124,7 @@ def test_module_skip_prefix():
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_skip_substr():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
......@@ -155,3 +161,23 @@ def test_module_skip_substr():
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
class raise_if_cuda_sync:
def __enter__(self):
self.previous_debug_mode = torch.cuda.get_sync_debug_mode()
torch.cuda.set_sync_debug_mode("error")
def __exit__(self, exception_type, exception_value, traceback):
torch.cuda.set_sync_debug_mode(self.previous_debug_mode)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_merge_multimodal_embeddings_no_sync():
inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device="cuda:0")
multimodal_embeddings = [torch.ones([3, 10], dtype=torch.bfloat16, device="cuda:0")]
is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")
with raise_if_cuda_sync():
_merge_multimodal_embeddings(
inputs_embeds, multimodal_embeddings, is_multimodal
)
......@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
# to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, are applied correctly regardless of whether or not
# we have multimodal tokens.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0)
in_vocab_ids = input_ids.masked_fill(
is_multimodal.to(device=input_ids.device, non_blocking=True), 0
)
return embed_input_ids(in_vocab_ids)
return embed_input_ids(input_ids)
......
......@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
tokenizer = cached_tokenizer_from_config(self.model_config)
# Generate video replacement token IDs using get_video_repl
......@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
)
# video_repl.full is a list of token IDs
repl_token_ids = torch.tensor(video_repl.full, device=device)
repl_token_ids = torch.tensor(video_repl.full)
# Get embedding token IDs for image context (use pre-tokenized version)
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device)
embed_token_ids = torch.tensor(self._img_context_token_ids)
# Create mask for video embedding positions
is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
......
......@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
# Scatter each modality to its positions
if video_embeds:
video_positions = is_video.nonzero(as_tuple=True)[0]
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
inputs_embeds[is_video] = torch.cat(video_embeds, dim=0)
if audio_embeds:
audio_positions = is_audio.nonzero(as_tuple=True)[0]
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
inputs_embeds[is_audio] = torch.cat(audio_embeds, dim=0)
if other_embeds:
other_mask = is_multimodal & ~is_video & ~is_audio
other_positions = other_mask.nonzero(as_tuple=True)[0]
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
inputs_embeds[other_mask] = torch.cat(other_embeds, dim=0)
return inputs_embeds
......@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
input_ids_cpu = input_ids.cpu()
is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()
......
......@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
# both the deepstack path and the final embedding merge.
video_token_id = self.config.video_token_id
audio_token_id = self.config.audio_token_id
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
input_ids_cpu = input_ids.cpu()
is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()
......
......@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
"""
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
......@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
select_token_id=self.is_multimodal_pruning_enabled,
)
repl_token_ids = torch.tensor(video_repl.full, device=device)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device)
repl_token_ids = torch.tensor(video_repl.full)
embed_token_id = _cached_tensor(
self.config.video_token_id, repl_token_ids.device
)
is_video_embed = torch.isin(repl_token_ids, embed_token_id)
# Get text embeddings for indicator tokens (has only `visual_dim``).
......
......@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
input_dtype = inputs_embeds.dtype
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
# If is_multimodal is on CPU this avoids a D2H sync
inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
......@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
raise ValueError("Error during masked scatter operation") from e
raise ValueError("Error during index put operation") from e
return inputs_embeds
......
......@@ -83,7 +83,7 @@ class EncoderRunner:
mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
total_num_scheduled_tokens, dtype=torch.bool, device="cpu"
)
for i, req_id in enumerate(req_ids):
if not is_prefilling[i]:
......@@ -131,8 +131,6 @@ class EncoderRunner:
)
mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
return mm_embeds, is_mm_embed
@torch.inference_mode()
......
......@@ -719,16 +719,6 @@ class GPUModelRunner(
self.max_num_reqs, dtype=torch.int32
)
# Only relevant for multimodal models
if self.supports_mm_inputs:
# Double buffer to avoid race condition: previous iteration's async
# copy may still be reading from CPU while current iteration writes.
self.is_mm_embed_buffers = [
self._make_buffer(self.max_num_tokens, dtype=torch.bool),
self._make_buffer(self.max_num_tokens, dtype=torch.bool),
]
self.is_mm_embed_idx = 0
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
......@@ -2910,14 +2900,10 @@ class GPUModelRunner(
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Swap to the other buffer to avoid race condition with previous
# iteration's async copy that may still be reading from CPU.
self.is_mm_embed_idx = 1 - self.is_mm_embed_idx
is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx]
mm_embeds = list[torch.Tensor]()
is_mm_embed = is_mm_embed_buf.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
is_mm_embed = torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device="cpu"
)
req_start_idx = 0
should_sync_mrope_positions = False
......@@ -3000,8 +2986,6 @@ class GPUModelRunner(
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment