Unverified commit 4f6eed3b, authored by Lukas Geiger and committed by GitHub
Browse files

[Core] Simplify multimodal masking (#34246)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 36d7f198
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
import pytest import pytest
import torch import torch
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import (
AutoWeightsLoader,
pytestmark = pytest.mark.cpu_test _merge_multimodal_embeddings,
)
from vllm.platforms import current_platform
class ModuleWithBatchNorm(torch.nn.Module): class ModuleWithBatchNorm(torch.nn.Module):
...@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module): ...@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
return self.nested_mod(x) return self.nested_mod(x)
@pytest.mark.cpu_test
def test_module_with_batchnorm_can_load(): def test_module_with_batchnorm_can_load():
"""Ensure the auto weight loader can load batchnorm stats.""" """Ensure the auto weight loader can load batchnorm stats."""
mod = ModuleWithBatchNorm() mod = ModuleWithBatchNorm()
...@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load(): ...@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
assert new_mod.bn.num_batches_tracked.item() == 1 assert new_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_with_child_containing_batchnorm_can_autoload(): def test_module_with_child_containing_batchnorm_can_autoload():
"""Ensure the auto weight loader can load nested modules batchnorm stats.""" """Ensure the auto weight loader can load nested modules batchnorm stats."""
mod = ModuleWithNestedBatchNorm() mod = ModuleWithNestedBatchNorm()
...@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload(): ...@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_skip_prefix(): def test_module_skip_prefix():
"""Ensure the auto weight loader can skip prefix.""" """Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm() mod = ModuleWithNestedBatchNorm()
...@@ -119,6 +124,7 @@ def test_module_skip_prefix(): ...@@ -119,6 +124,7 @@ def test_module_skip_prefix():
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@pytest.mark.cpu_test
def test_module_skip_substr(): def test_module_skip_substr():
"""Ensure the auto weight loader can skip prefix.""" """Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm() mod = ModuleWithNestedBatchNorm()
...@@ -155,3 +161,23 @@ def test_module_skip_substr(): ...@@ -155,3 +161,23 @@ def test_module_skip_substr():
) )
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
class raise_if_cuda_sync:
    """Context manager that switches the CUDA sync debug mode to ``"error"``.

    On entry the mode currently reported by
    ``torch.cuda.get_sync_debug_mode()`` is remembered and the mode is set to
    ``"error"``. On exit the remembered mode is restored.
    """

    def __enter__(self):
        # Remember the active mode first so that __exit__ can put it back.
        current_mode = torch.cuda.get_sync_debug_mode()
        self.previous_debug_mode = current_mode
        torch.cuda.set_sync_debug_mode("error")

    def __exit__(self, exception_type, exception_value, traceback):
        # Implicitly returns None, so an exception raised inside the
        # ``with`` body is not suppressed.
        torch.cuda.set_sync_debug_mode(self.previous_debug_mode)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_merge_multimodal_embeddings_no_sync():
    """Merge with a CPU mask under the sync guard and verify the result.

    The embeddings live on ``cuda:0`` while ``is_multimodal`` is a CPU bool
    tensor. The merge runs inside ``raise_if_cuda_sync``; the merged values
    are then checked after the guard has been left.
    """
    inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device="cuda:0")
    multimodal_embeddings = [torch.ones([3, 10], dtype=torch.bfloat16, device="cuda:0")]
    is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")

    with raise_if_cuda_sync():
        merged = _merge_multimodal_embeddings(
            inputs_embeds, multimodal_embeddings, is_multimodal
        )

    # Compare only after leaving the guard: copying the result to the host
    # synchronizes, which must not happen while the "error" mode is active.
    expected = torch.zeros([5, 10], dtype=torch.bfloat16)
    expected[is_multimodal] = 1
    assert torch.equal(merged.cpu(), expected)
...@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol): ...@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
# to ensure that any external configuration requiring offset tracking, # to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, are applied correctly regardless of whether or not # e.g., LoRA, are applied correctly regardless of whether or not
# we have multimodal tokens. # we have multimodal tokens.
in_vocab_ids = input_ids.masked_fill(is_multimodal, 0) in_vocab_ids = input_ids.masked_fill(
is_multimodal.to(device=input_ids.device, non_blocking=True), 0
)
return embed_input_ids(in_vocab_ids) return embed_input_ids(in_vocab_ids)
return embed_input_ids(input_ids) return embed_input_ids(input_ids)
......
...@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2( ...@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
These embeddings will replace the placeholder embeddings to create These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM. input_embeds for the LLM.
""" """
device = video_embeddings.device
tokenizer = cached_tokenizer_from_config(self.model_config) tokenizer = cached_tokenizer_from_config(self.model_config)
# Generate video replacement token IDs using get_video_repl # Generate video replacement token IDs using get_video_repl
...@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2( ...@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
) )
# video_repl.full is a list of token IDs # video_repl.full is a list of token IDs
repl_token_ids = torch.tensor(video_repl.full, device=device) repl_token_ids = torch.tensor(video_repl.full)
# Get embedding token IDs for image context (use pre-tokenized version) # Get embedding token IDs for image context (use pre-tokenized version)
embed_token_ids = torch.tensor(self._img_context_token_ids, device=device) embed_token_ids = torch.tensor(self._img_context_token_ids)
# Create mask for video embedding positions # Create mask for video embedding positions
is_video_embed = torch.isin(repl_token_ids, embed_token_ids) is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
......
...@@ -211,15 +211,12 @@ def merge_interleaved_embeddings( ...@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
# Scatter each modality to its positions # Scatter each modality to its positions
if video_embeds: if video_embeds:
video_positions = is_video.nonzero(as_tuple=True)[0] inputs_embeds[is_video] = torch.cat(video_embeds, dim=0)
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
if audio_embeds: if audio_embeds:
audio_positions = is_audio.nonzero(as_tuple=True)[0] inputs_embeds[is_audio] = torch.cat(audio_embeds, dim=0)
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
if other_embeds: if other_embeds:
other_mask = is_multimodal & ~is_video & ~is_audio other_mask = is_multimodal & ~is_video & ~is_audio
other_positions = other_mask.nonzero(as_tuple=True)[0] inputs_embeds[other_mask] = torch.cat(other_embeds, dim=0)
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
return inputs_embeds return inputs_embeds
...@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
video_token_id = self.config.video_token_index video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index audio_token_id = self.config.audio_token_index
is_video = is_multimodal & (input_ids == video_token_id) input_ids_cpu = input_ids.cpu()
is_audio = is_multimodal & (input_ids == audio_token_id) is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item() num_video = is_video.sum().item()
num_audio = is_audio.sum().item() num_audio = is_audio.sum().item()
......
...@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
# both the deepstack path and the final embedding merge. # both the deepstack path and the final embedding merge.
video_token_id = self.config.video_token_id video_token_id = self.config.video_token_id
audio_token_id = self.config.audio_token_id audio_token_id = self.config.audio_token_id
is_video = is_multimodal & (input_ids == video_token_id) input_ids_cpu = input_ids.cpu()
is_audio = is_multimodal & (input_ids == audio_token_id) is_video = is_multimodal & (input_ids_cpu == video_token_id)
is_audio = is_multimodal & (input_ids_cpu == audio_token_id)
num_video = is_video.sum().item() num_video = is_video.sum().item()
num_audio = is_audio.sum().item() num_audio = is_audio.sum().item()
......
...@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration( ...@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
These embeddings will replace the placeholder embeddings to create These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM. input_embeds for the LLM.
""" """
device = video_embeddings.device
# Generate video replacement token IDs using get_video_repl # Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized # This tokenizes each frame separator independently, then uses pre-tokenized
...@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration( ...@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
select_token_id=self.is_multimodal_pruning_enabled, select_token_id=self.is_multimodal_pruning_enabled,
) )
repl_token_ids = torch.tensor(video_repl.full, device=device) repl_token_ids = torch.tensor(video_repl.full)
embed_token_id = _cached_tensor(self.config.video_token_id, device=device) embed_token_id = _cached_tensor(
self.config.video_token_id, repl_token_ids.device
)
is_video_embed = torch.isin(repl_token_ids, embed_token_id) is_video_embed = torch.isin(repl_token_ids, embed_token_id)
# Get text embeddings for indicator tokens (has only `visual_dim``). # Get text embeddings for indicator tokens (has only `visual_dim``).
......
...@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings( ...@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
input_dtype = inputs_embeds.dtype input_dtype = inputs_embeds.dtype
try: try:
# For debugging # If is_multimodal is on CPU this avoids a D2H sync
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype) inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
except RuntimeError as e: except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat) num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item() num_expected_tokens = is_multimodal.sum().item()
...@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings( ...@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
f"multimodal tokens to {num_expected_tokens} placeholders" f"multimodal tokens to {num_expected_tokens} placeholders"
) from e ) from e
raise ValueError("Error during masked scatter operation") from e raise ValueError("Error during index put operation") from e
return inputs_embeds return inputs_embeds
......
...@@ -83,7 +83,7 @@ class EncoderRunner: ...@@ -83,7 +83,7 @@ class EncoderRunner:
mm_embeds: list[torch.Tensor] = [] mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros( is_mm_embed = torch.zeros(
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True total_num_scheduled_tokens, dtype=torch.bool, device="cpu"
) )
for i, req_id in enumerate(req_ids): for i, req_id in enumerate(req_ids):
if not is_prefilling[i]: if not is_prefilling[i]:
...@@ -131,8 +131,6 @@ class EncoderRunner: ...@@ -131,8 +131,6 @@ class EncoderRunner:
) )
mm_embeds.append(mm_embeds_item) mm_embeds.append(mm_embeds_item)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
return mm_embeds, is_mm_embed return mm_embeds, is_mm_embed
@torch.inference_mode() @torch.inference_mode()
......
...@@ -719,16 +719,6 @@ class GPUModelRunner( ...@@ -719,16 +719,6 @@ class GPUModelRunner(
self.max_num_reqs, dtype=torch.int32 self.max_num_reqs, dtype=torch.int32
) )
# Only relevant for multimodal models
if self.supports_mm_inputs:
# Double buffer to avoid race condition: previous iteration's async
# copy may still be reading from CPU while current iteration writes.
self.is_mm_embed_buffers = [
self._make_buffer(self.max_num_tokens, dtype=torch.bool),
self._make_buffer(self.max_num_tokens, dtype=torch.bool),
]
self.is_mm_embed_idx = 0
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope: if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy # NOTE: `mrope_positions` is implemented with one additional dummy
...@@ -2910,14 +2900,10 @@ class GPUModelRunner( ...@@ -2910,14 +2900,10 @@ class GPUModelRunner(
) -> tuple[list[torch.Tensor], torch.Tensor]: ) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Swap to the other buffer to avoid race condition with previous
# iteration's async copy that may still be reading from CPU.
self.is_mm_embed_idx = 1 - self.is_mm_embed_idx
is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx]
mm_embeds = list[torch.Tensor]() mm_embeds = list[torch.Tensor]()
is_mm_embed = is_mm_embed_buf.cpu is_mm_embed = torch.zeros(
is_mm_embed[:total_num_scheduled_tokens] = False total_num_scheduled_tokens, dtype=torch.bool, device="cpu"
)
req_start_idx = 0 req_start_idx = 0
should_sync_mrope_positions = False should_sync_mrope_positions = False
...@@ -3000,8 +2986,6 @@ class GPUModelRunner( ...@@ -3000,8 +2986,6 @@ class GPUModelRunner(
mm_embeds.extend(mm_embeds_req) mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens req_start_idx += num_scheduled_tokens
is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions: if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output) self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens) self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment