Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f6eed3b
Unverified
Commit
4f6eed3b
authored
Apr 01, 2026
by
Lukas Geiger
Committed by
GitHub
Apr 01, 2026
Browse files
[Core] Simplify multimodal masking (#34246)
Signed-off-by:
Lukas Geiger
<
lukas.geiger94@gmail.com
>
parent
36d7f198
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
54 additions
and
51 deletions
+54
-51
tests/models/test_utils.py
tests/models/test_utils.py
+29
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+3
-1
vllm/model_executor/models/nano_nemotron_vl.py
vllm/model_executor/models/nano_nemotron_vl.py
+2
-3
vllm/model_executor/models/qwen2_5_omni_thinker.py
vllm/model_executor/models/qwen2_5_omni_thinker.py
+6
-8
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+3
-2
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+4
-3
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+3
-9
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+1
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-19
No files found.
tests/models/test_utils.py
View file @
4f6eed3b
...
@@ -4,9 +4,11 @@
...
@@ -4,9 +4,11 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
pytestmark
=
pytest
.
mark
.
cpu_test
_merge_multimodal_embeddings
,
)
from
vllm.platforms
import
current_platform
class
ModuleWithBatchNorm
(
torch
.
nn
.
Module
):
class
ModuleWithBatchNorm
(
torch
.
nn
.
Module
):
...
@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
...
@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
return
self
.
nested_mod
(
x
)
return
self
.
nested_mod
(
x
)
@
pytest
.
mark
.
cpu_test
def
test_module_with_batchnorm_can_load
():
def
test_module_with_batchnorm_can_load
():
"""Ensure the auto weight loader can load batchnorm stats."""
"""Ensure the auto weight loader can load batchnorm stats."""
mod
=
ModuleWithBatchNorm
()
mod
=
ModuleWithBatchNorm
()
...
@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
...
@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
assert
new_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
assert
new_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
@
pytest
.
mark
.
cpu_test
def
test_module_with_child_containing_batchnorm_can_autoload
():
def
test_module_with_child_containing_batchnorm_can_autoload
():
"""Ensure the auto weight loader can load nested modules batchnorm stats."""
"""Ensure the auto weight loader can load nested modules batchnorm stats."""
mod
=
ModuleWithNestedBatchNorm
()
mod
=
ModuleWithNestedBatchNorm
()
...
@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
...
@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
@
pytest
.
mark
.
cpu_test
def
test_module_skip_prefix
():
def
test_module_skip_prefix
():
"""Ensure the auto weight loader can skip prefix."""
"""Ensure the auto weight loader can skip prefix."""
mod
=
ModuleWithNestedBatchNorm
()
mod
=
ModuleWithNestedBatchNorm
()
...
@@ -119,6 +124,7 @@ def test_module_skip_prefix():
...
@@ -119,6 +124,7 @@ def test_module_skip_prefix():
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
@
pytest
.
mark
.
cpu_test
def
test_module_skip_substr
():
def
test_module_skip_substr
():
"""Ensure the auto weight loader can skip prefix."""
"""Ensure the auto weight loader can skip prefix."""
mod
=
ModuleWithNestedBatchNorm
()
mod
=
ModuleWithNestedBatchNorm
()
...
@@ -155,3 +161,23 @@ def test_module_skip_substr():
...
@@ -155,3 +161,23 @@ def test_module_skip_substr():
)
)
assert
torch
.
all
(
new_mod
.
nested_mod
.
bn
.
running_var
==
mod
.
nested_mod
.
bn
.
running_var
)
assert
torch
.
all
(
new_mod
.
nested_mod
.
bn
.
running_var
==
mod
.
nested_mod
.
bn
.
running_var
)
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
assert
new_mod
.
nested_mod
.
bn
.
num_batches_tracked
.
item
()
==
1
class
raise_if_cuda_sync
:
def
__enter__
(
self
):
self
.
previous_debug_mode
=
torch
.
cuda
.
get_sync_debug_mode
()
torch
.
cuda
.
set_sync_debug_mode
(
"error"
)
def
__exit__
(
self
,
exception_type
,
exception_value
,
traceback
):
torch
.
cuda
.
set_sync_debug_mode
(
self
.
previous_debug_mode
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
def
test_merge_multimodal_embeddings_no_sync
():
inputs_embeds
=
torch
.
zeros
([
5
,
10
],
dtype
=
torch
.
bfloat16
,
device
=
"cuda:0"
)
multimodal_embeddings
=
[
torch
.
ones
([
3
,
10
],
dtype
=
torch
.
bfloat16
,
device
=
"cuda:0"
)]
is_multimodal
=
torch
.
tensor
([
True
,
False
,
True
,
True
,
False
],
device
=
"cpu"
)
with
raise_if_cuda_sync
():
_merge_multimodal_embeddings
(
inputs_embeds
,
multimodal_embeddings
,
is_multimodal
)
vllm/model_executor/models/interfaces.py
View file @
4f6eed3b
...
@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
...
@@ -362,7 +362,9 @@ class SupportsMultiModal(Protocol):
# to ensure that any external configuration requiring offset tracking,
# to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, are applied correctly regardless of whether or not
# e.g., LoRA, are applied correctly regardless of whether or not
# we have multimodal tokens.
# we have multimodal tokens.
in_vocab_ids
=
input_ids
.
masked_fill
(
is_multimodal
,
0
)
in_vocab_ids
=
input_ids
.
masked_fill
(
is_multimodal
.
to
(
device
=
input_ids
.
device
,
non_blocking
=
True
),
0
)
return
embed_input_ids
(
in_vocab_ids
)
return
embed_input_ids
(
in_vocab_ids
)
return
embed_input_ids
(
input_ids
)
return
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/nano_nemotron_vl.py
View file @
4f6eed3b
...
@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
...
@@ -1215,7 +1215,6 @@ class NemotronH_Nano_VL_V2(
These embeddings will replace the placeholder embeddings to create
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
input_embeds for the LLM.
"""
"""
device
=
video_embeddings
.
device
tokenizer
=
cached_tokenizer_from_config
(
self
.
model_config
)
tokenizer
=
cached_tokenizer_from_config
(
self
.
model_config
)
# Generate video replacement token IDs using get_video_repl
# Generate video replacement token IDs using get_video_repl
...
@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
...
@@ -1234,10 +1233,10 @@ class NemotronH_Nano_VL_V2(
)
)
# video_repl.full is a list of token IDs
# video_repl.full is a list of token IDs
repl_token_ids
=
torch
.
tensor
(
video_repl
.
full
,
device
=
device
)
repl_token_ids
=
torch
.
tensor
(
video_repl
.
full
)
# Get embedding token IDs for image context (use pre-tokenized version)
# Get embedding token IDs for image context (use pre-tokenized version)
embed_token_ids
=
torch
.
tensor
(
self
.
_img_context_token_ids
,
device
=
device
)
embed_token_ids
=
torch
.
tensor
(
self
.
_img_context_token_ids
)
# Create mask for video embedding positions
# Create mask for video embedding positions
is_video_embed
=
torch
.
isin
(
repl_token_ids
,
embed_token_ids
)
is_video_embed
=
torch
.
isin
(
repl_token_ids
,
embed_token_ids
)
...
...
vllm/model_executor/models/qwen2_5_omni_thinker.py
View file @
4f6eed3b
...
@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
...
@@ -211,15 +211,12 @@ def merge_interleaved_embeddings(
# Scatter each modality to its positions
# Scatter each modality to its positions
if
video_embeds
:
if
video_embeds
:
video_positions
=
is_video
.
nonzero
(
as_tuple
=
True
)[
0
]
inputs_embeds
[
is_video
]
=
torch
.
cat
(
video_embeds
,
dim
=
0
)
inputs_embeds
[
video_positions
]
=
torch
.
cat
(
video_embeds
,
dim
=
0
)
if
audio_embeds
:
if
audio_embeds
:
audio_positions
=
is_audio
.
nonzero
(
as_tuple
=
True
)[
0
]
inputs_embeds
[
is_audio
]
=
torch
.
cat
(
audio_embeds
,
dim
=
0
)
inputs_embeds
[
audio_positions
]
=
torch
.
cat
(
audio_embeds
,
dim
=
0
)
if
other_embeds
:
if
other_embeds
:
other_mask
=
is_multimodal
&
~
is_video
&
~
is_audio
other_mask
=
is_multimodal
&
~
is_video
&
~
is_audio
other_positions
=
other_mask
.
nonzero
(
as_tuple
=
True
)[
0
]
inputs_embeds
[
other_mask
]
=
torch
.
cat
(
other_embeds
,
dim
=
0
)
inputs_embeds
[
other_positions
]
=
torch
.
cat
(
other_embeds
,
dim
=
0
)
return
inputs_embeds
return
inputs_embeds
...
@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
...
@@ -1457,8 +1454,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
video_token_id
=
self
.
config
.
video_token_index
video_token_id
=
self
.
config
.
video_token_index
audio_token_id
=
self
.
config
.
audio_token_index
audio_token_id
=
self
.
config
.
audio_token_index
is_video
=
is_multimodal
&
(
input_ids
==
video_token_id
)
input_ids_cpu
=
input_ids
.
cpu
()
is_audio
=
is_multimodal
&
(
input_ids
==
audio_token_id
)
is_video
=
is_multimodal
&
(
input_ids_cpu
==
video_token_id
)
is_audio
=
is_multimodal
&
(
input_ids_cpu
==
audio_token_id
)
num_video
=
is_video
.
sum
().
item
()
num_video
=
is_video
.
sum
().
item
()
num_audio
=
is_audio
.
sum
().
item
()
num_audio
=
is_audio
.
sum
().
item
()
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
4f6eed3b
...
@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
...
@@ -1869,8 +1869,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
# both the deepstack path and the final embedding merge.
# both the deepstack path and the final embedding merge.
video_token_id
=
self
.
config
.
video_token_id
video_token_id
=
self
.
config
.
video_token_id
audio_token_id
=
self
.
config
.
audio_token_id
audio_token_id
=
self
.
config
.
audio_token_id
is_video
=
is_multimodal
&
(
input_ids
==
video_token_id
)
input_ids_cpu
=
input_ids
.
cpu
()
is_audio
=
is_multimodal
&
(
input_ids
==
audio_token_id
)
is_video
=
is_multimodal
&
(
input_ids_cpu
==
video_token_id
)
is_audio
=
is_multimodal
&
(
input_ids_cpu
==
audio_token_id
)
num_video
=
is_video
.
sum
().
item
()
num_video
=
is_video
.
sum
().
item
()
num_audio
=
is_audio
.
sum
().
item
()
num_audio
=
is_audio
.
sum
().
item
()
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
4f6eed3b
...
@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
...
@@ -1977,7 +1977,6 @@ class Qwen3VLForConditionalGeneration(
These embeddings will replace the placeholder embeddings to create
These embeddings will replace the placeholder embeddings to create
input_embeds for the LLM.
input_embeds for the LLM.
"""
"""
device
=
video_embeddings
.
device
# Generate video replacement token IDs using get_video_repl
# Generate video replacement token IDs using get_video_repl
# This tokenizes each frame separator independently, then uses pre-tokenized
# This tokenizes each frame separator independently, then uses pre-tokenized
...
@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
...
@@ -1993,8 +1992,10 @@ class Qwen3VLForConditionalGeneration(
select_token_id
=
self
.
is_multimodal_pruning_enabled
,
select_token_id
=
self
.
is_multimodal_pruning_enabled
,
)
)
repl_token_ids
=
torch
.
tensor
(
video_repl
.
full
,
device
=
device
)
repl_token_ids
=
torch
.
tensor
(
video_repl
.
full
)
embed_token_id
=
_cached_tensor
(
self
.
config
.
video_token_id
,
device
=
device
)
embed_token_id
=
_cached_tensor
(
self
.
config
.
video_token_id
,
repl_token_ids
.
device
)
is_video_embed
=
torch
.
isin
(
repl_token_ids
,
embed_token_id
)
is_video_embed
=
torch
.
isin
(
repl_token_ids
,
embed_token_id
)
# Get text embeddings for indicator tokens (has only `visual_dim``).
# Get text embeddings for indicator tokens (has only `visual_dim``).
...
...
vllm/model_executor/models/utils.py
View file @
4f6eed3b
...
@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
...
@@ -468,14 +468,8 @@ def _merge_multimodal_embeddings(
input_dtype
=
inputs_embeds
.
dtype
input_dtype
=
inputs_embeds
.
dtype
try
:
try
:
# For debugging
# If is_multimodal is on CPU this avoids a D2H sync
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
inputs_embeds
[
is_multimodal
]
=
mm_embeds_flat
.
to
(
dtype
=
input_dtype
)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds
.
masked_scatter_
(
is_multimodal
.
unsqueeze
(
-
1
),
mm_embeds_flat
.
to
(
dtype
=
input_dtype
)
)
except
RuntimeError
as
e
:
except
RuntimeError
as
e
:
num_actual_tokens
=
len
(
mm_embeds_flat
)
num_actual_tokens
=
len
(
mm_embeds_flat
)
num_expected_tokens
=
is_multimodal
.
sum
().
item
()
num_expected_tokens
=
is_multimodal
.
sum
().
item
()
...
@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
...
@@ -488,7 +482,7 @@ def _merge_multimodal_embeddings(
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
from
e
)
from
e
raise
ValueError
(
"Error during
masked scatter
operation"
)
from
e
raise
ValueError
(
"Error during
index put
operation"
)
from
e
return
inputs_embeds
return
inputs_embeds
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
4f6eed3b
...
@@ -83,7 +83,7 @@ class EncoderRunner:
...
@@ -83,7 +83,7 @@ class EncoderRunner:
mm_embeds
:
list
[
torch
.
Tensor
]
=
[]
mm_embeds
:
list
[
torch
.
Tensor
]
=
[]
is_mm_embed
=
torch
.
zeros
(
is_mm_embed
=
torch
.
zeros
(
total_num_scheduled_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
True
total_num_scheduled_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
)
for
i
,
req_id
in
enumerate
(
req_ids
):
for
i
,
req_id
in
enumerate
(
req_ids
):
if
not
is_prefilling
[
i
]:
if
not
is_prefilling
[
i
]:
...
@@ -131,8 +131,6 @@ class EncoderRunner:
...
@@ -131,8 +131,6 @@ class EncoderRunner:
)
)
mm_embeds
.
append
(
mm_embeds_item
)
mm_embeds
.
append
(
mm_embeds_item
)
# Copy the is_mm_embed tensor to the GPU.
is_mm_embed
=
is_mm_embed
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
return
mm_embeds
,
is_mm_embed
return
mm_embeds
,
is_mm_embed
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4f6eed3b
...
@@ -719,16 +719,6 @@ class GPUModelRunner(
...
@@ -719,16 +719,6 @@ class GPUModelRunner(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
)
# Only relevant for multimodal models
if
self
.
supports_mm_inputs
:
# Double buffer to avoid race condition: previous iteration's async
# copy may still be reading from CPU while current iteration writes.
self
.
is_mm_embed_buffers
=
[
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
bool
),
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
bool
),
]
self
.
is_mm_embed_idx
=
0
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
# NOTE: `mrope_positions` is implemented with one additional dummy
# NOTE: `mrope_positions` is implemented with one additional dummy
...
@@ -2910,14 +2900,10 @@ class GPUModelRunner(
...
@@ -2910,14 +2900,10 @@ class GPUModelRunner(
)
->
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# Swap to the other buffer to avoid race condition with previous
# iteration's async copy that may still be reading from CPU.
self
.
is_mm_embed_idx
=
1
-
self
.
is_mm_embed_idx
is_mm_embed_buf
=
self
.
is_mm_embed_buffers
[
self
.
is_mm_embed_idx
]
mm_embeds
=
list
[
torch
.
Tensor
]()
mm_embeds
=
list
[
torch
.
Tensor
]()
is_mm_embed
=
is_mm_embed_buf
.
cpu
is_mm_embed
=
torch
.
zeros
(
is_mm_embed
[:
total_num_scheduled_tokens
]
=
False
total_num_scheduled_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
req_start_idx
=
0
req_start_idx
=
0
should_sync_mrope_positions
=
False
should_sync_mrope_positions
=
False
...
@@ -3000,8 +2986,6 @@ class GPUModelRunner(
...
@@ -3000,8 +2986,6 @@ class GPUModelRunner(
mm_embeds
.
extend
(
mm_embeds_req
)
mm_embeds
.
extend
(
mm_embeds_req
)
req_start_idx
+=
num_scheduled_tokens
req_start_idx
+=
num_scheduled_tokens
is_mm_embed
=
is_mm_embed_buf
.
copy_to_gpu
(
total_num_scheduled_tokens
)
if
should_sync_mrope_positions
:
if
should_sync_mrope_positions
:
self
.
_calc_mrope_positions
(
scheduler_output
)
self
.
_calc_mrope_positions
(
scheduler_output
)
self
.
mrope_positions
.
copy_to_gpu
(
total_num_scheduled_tokens
)
self
.
mrope_positions
.
copy_to_gpu
(
total_num_scheduled_tokens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment