Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
babad6e5
Unverified
Commit
babad6e5
authored
Sep 23, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 23, 2025
Browse files
[Misc] Move DP for ViT code inside model executor dir (#25459)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
9383cd6f
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
721 additions
and
730 deletions
+721
-730
tests/models/test_vision.py
tests/models/test_vision.py
+423
-1
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+1
-425
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+1
-2
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+2
-1
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+2
-1
vllm/model_executor/models/kimi_vl.py
vllm/model_executor/models/kimi_vl.py
+1
-1
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+1
-1
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+1
-2
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1
-2
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+1
-5
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+1
-1
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+280
-1
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+6
-287
No files found.
tests/models/test_vision.py
View file @
babad6e5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
from
vllm.model_executor.models.vision
import
resolve_visual_encoder_outputs
from
tests.utils
import
multi_gpu_test
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.model_executor.models.vision
import
(
get_load_balance_assignment
,
resolve_visual_encoder_outputs
,
run_dp_sharded_mrope_vision_model
,
run_dp_sharded_vision_model
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_open_port
,
update_environment_variables
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
...
@@ -33,3 +43,415 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
post_layer_norm
=
None
,
post_layer_norm
=
None
,
max_possible_layers
=
max_possible_layers
)
max_possible_layers
=
max_possible_layers
)
assert
torch
.
equal
(
torch
.
tensor
(
expected_features
),
output_tensor
)
assert
torch
.
equal
(
torch
.
tensor
(
expected_features
),
output_tensor
)
class
SimpleLinearModel
(
torch
.
nn
.
Module
):
"""A simple linear vision model for testing."""
def
__init__
(
self
,
input_dim
:
int
=
3
*
224
*
224
,
output_dim
:
int
=
32
):
super
().
__init__
()
self
.
flatten
=
torch
.
nn
.
Flatten
()
self
.
linear
=
torch
.
nn
.
Linear
(
input_dim
,
output_dim
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
# Flatten the input and apply linear transformation
x
=
self
.
flatten
(
x
)
return
self
.
linear
(
x
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
# Single image
4
,
# Small batch
5
,
# Odd batch size (for testing padding)
],
)
def
test_run_dp_sharded_vision_model
(
batch_size
:
int
):
world_size
=
2
# Launch processes
mp
.
spawn
(
run_dp_sharded_vision_model_vs_direct
,
args
=
(
world_size
,
batch_size
,
get_open_port
(),
),
nprocs
=
world_size
,
)
def
run_dp_sharded_vision_model_vs_direct
(
local_rank
:
int
,
world_size
:
int
,
batch_size
:
int
,
master_port
:
int
):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
# initialize distributed
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create a test input tensor
image_input
=
torch
.
randn
(
batch_size
,
3
,
224
,
224
)
# Create a simple linear model
vision_model
=
SimpleLinearModel
()
# Run the model directly on the full input
with
torch
.
inference_mode
():
direct_output
=
vision_model
(
image_input
)
# Run the model through the sharded function
with
torch
.
inference_mode
():
sharded_output
=
run_dp_sharded_vision_model
(
image_input
,
vision_model
)
# Check that the world size is set up correctly
assert
get_tensor_model_parallel_world_size
()
==
world_size
# Check that the outputs have the same shape
assert
direct_output
.
shape
==
sharded_output
.
shape
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description"
,
[
# Empty input
([],
2
,
[],
[
0
,
0
],
[
0
,
0
],
"empty input"
),
# Fewer samples than GPUs
([
100
,
200
],
4
,
[
1
,
0
],
[
1
,
1
,
0
,
0
],
[
200
,
100
,
0
,
0
],
"fewer samples than GPUs"
),
# Single GPU
([
100
,
200
,
300
],
1
,
[
2
,
1
,
0
],
[
3
],
[
600
],
"single GPU"
),
# Balanced assignment
([
100
,
100
,
100
,
100
],
2
,
[
0
,
2
,
1
,
3
],
[
2
,
2
],
[
200
,
200
],
"balanced assignment"
),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([
1000
,
100
,
200
,
50
],
2
,
[
0
,
2
,
1
,
3
],
[
1
,
3
],
[
1000
,
350
],
"unbalanced sizes"
),
],
)
def
test_get_load_balance_assignment_cases
(
sizes
,
num_gpus
,
expected_shuffle_indices
,
expected_gpu_sample_counts
,
expected_grouped_sizes_per_gpu
,
test_description
):
"""Test get_load_balance_assignment with various input cases."""
result
=
get_load_balance_assignment
(
sizes
,
num_gpus
=
num_gpus
)
(
shuffle_indices
,
gpu_sample_counts
,
grouped_sizes_per_gpu
)
=
result
# Common assertions for all cases
assert
len
(
shuffle_indices
)
==
len
(
sizes
)
assert
len
(
gpu_sample_counts
)
==
num_gpus
assert
len
(
grouped_sizes_per_gpu
)
==
num_gpus
assert
sum
(
gpu_sample_counts
)
==
len
(
sizes
)
assert
shuffle_indices
==
expected_shuffle_indices
assert
gpu_sample_counts
==
expected_gpu_sample_counts
assert
grouped_sizes_per_gpu
==
expected_grouped_sizes_per_gpu
class
SimpleMRopeVisionModel
(
torch
.
nn
.
Module
):
"""A simple vision model for testing mrope functionality."""
def
__init__
(
self
,
spatial_merge_size
:
int
=
2
,
out_hidden_size
:
int
=
64
):
super
().
__init__
()
self
.
spatial_merge_size
=
spatial_merge_size
self
.
out_hidden_size
=
out_hidden_size
self
.
linear
=
torch
.
nn
.
Linear
(
768
,
out_hidden_size
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
grid_thw_list
:
list
[
list
[
int
]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings
=
self
.
linear
(
pixel_values
)
# Simulate spatial merging by reducing the number of patches
merge_factor
=
self
.
spatial_merge_size
*
self
.
spatial_merge_size
# Group patches and merge spatially
merged_embeddings
=
[]
start_idx
=
0
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
end_idx
=
start_idx
+
num_patches
# Get patches for this image
image_patches
=
embeddings
[
start_idx
:
end_idx
]
# Simulate spatial merging by averaging groups of patches
merged_patches
=
num_patches
//
merge_factor
if
merged_patches
>
0
:
# Reshape and average to simulate merging
reshaped
=
image_patches
[:
merged_patches
*
merge_factor
].
view
(
merged_patches
,
merge_factor
,
-
1
)
merged
=
reshaped
.
mean
(
dim
=
1
)
merged_embeddings
.
append
(
merged
)
start_idx
=
end_idx
if
merged_embeddings
:
return
torch
.
cat
(
merged_embeddings
,
dim
=
0
)
else
:
return
torch
.
empty
((
0
,
self
.
out_hidden_size
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
# Single image
3
,
# Small batch
5
,
# Odd batch size (for testing padding)
],
)
def
test_run_dp_sharded_mrope_vision_model
(
batch_size
:
int
):
world_size
=
2
# Launch processes
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_vs_direct
,
args
=
(
world_size
,
batch_size
,
get_open_port
(),
),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_vs_direct
(
local_rank
:
int
,
world_size
:
int
,
batch_size
:
int
,
master_port
:
int
):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
# initialize distributed
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create test data
grid_thw_list
=
[]
pixel_values_list
=
[]
for
i
in
range
(
batch_size
):
# Varying image sizes for better testing
t
,
h
,
w
=
1
,
4
+
i
,
4
+
i
grid_thw_list
.
append
([
t
,
h
,
w
])
num_patches
=
t
*
h
*
w
# Create random pixel values for this image
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
# Concatenate all pixel values
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
# Create a simple mrope vision model
vision_model
=
SimpleMRopeVisionModel
()
# Run the model directly on the full input (only on rank 0)
if
local_rank
==
0
:
with
torch
.
inference_mode
():
direct_output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Run the model through the sharded function
with
torch
.
inference_mode
():
sharded_output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
sharded_output
=
torch
.
cat
(
sharded_output
,
dim
=
0
)
# Check that the world size is set up correctly
assert
get_tensor_model_parallel_world_size
()
==
world_size
# Compare outputs (only on rank 0)
if
local_rank
==
0
:
# Check that the outputs have the same shape
assert
direct_output
.
shape
==
sharded_output
.
shape
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_run_dp_sharded_mrope_vision_model_empty_input
():
world_size
=
2
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_empty_input_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_empty_input_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create empty inputs
pixel_values
=
torch
.
empty
((
0
,
768
))
grid_thw_list
:
list
[
list
[
int
]]
=
[]
vision_model
=
SimpleMRopeVisionModel
()
# Should handle empty input gracefully
with
torch
.
inference_mode
():
output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
assert
len
(
output
)
==
0
@
multi_gpu_test
(
num_gpus
=
4
)
def
test_run_dp_sharded_mrope_vision_model_uneven_load
():
world_size
=
4
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_uneven_load_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_uneven_load_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform
.
seed_everything
(
123
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create images with very different sizes
grid_thw_list
=
[
[
1
,
2
,
2
],
# Small: 4 patches
[
1
,
8
,
8
],
# Large: 64 patches
[
1
,
3
,
3
],
# Medium: 9 patches
]
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
()
# Should handle uneven distribution without errors
with
torch
.
inference_mode
():
output_tuple
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
# Verify output shape is reasonable
merge_factor
=
vision_model
.
spatial_merge_size
**
2
expected_output_patches
=
list
(
math
.
prod
(
grid_thw
)
//
merge_factor
for
grid_thw
in
grid_thw_list
)
for
i
,
output
in
enumerate
(
output_tuple
):
assert
output
.
shape
[
0
]
==
expected_output_patches
[
i
]
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
@
pytest
.
mark
.
parametrize
(
"spatial_merge_size"
,
[
2
,
4
])
def
test_simple_mrope_vision_model_spatial_merge
(
spatial_merge_size
:
int
):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device
=
current_platform
.
device_type
grid_thw_list
=
[[
1
,
4
,
4
],
[
1
,
6
,
6
]]
# Two images
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
,
device
=
device
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
(
spatial_merge_size
=
spatial_merge_size
).
to
(
device
)
with
torch
.
inference_mode
():
output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Verify output dimensions based on spatial merging
total_patches
=
sum
(
math
.
prod
(
grid_thw
)
for
grid_thw
in
grid_thw_list
)
merge_factor
=
spatial_merge_size
**
2
expected_output_patches
=
total_patches
//
merge_factor
assert
output
.
shape
[
0
]
==
expected_output_patches
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
tests/multimodal/test_utils.py
View file @
babad6e5
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
base64
import
math
import
mimetypes
import
mimetypes
import
os
import
os
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
from
tempfile
import
NamedTemporaryFile
,
TemporaryDirectory
...
@@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
...
@@ -10,22 +9,11 @@ from typing import TYPE_CHECKING, NamedTuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
PIL
import
Image
,
ImageChops
from
PIL
import
Image
,
ImageChops
from
tests.utils
import
multi_gpu_test
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.image
import
convert_image_mode
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.utils
import
(
MediaConnector
,
argsort_mm_positions
,
from
vllm.multimodal.utils
import
MediaConnector
,
argsort_mm_positions
get_load_balance_assignment
,
run_dp_sharded_mrope_vision_model
,
run_dp_sharded_vision_model
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_open_port
,
update_environment_variables
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal.inputs
import
MultiModalPlaceholderDict
from
vllm.multimodal.inputs
import
MultiModalPlaceholderDict
...
@@ -404,415 +392,3 @@ def test_argsort_mm_positions():
...
@@ -404,415 +392,3 @@ def test_argsort_mm_positions():
modality_idxs
=
argsort_mm_positions
(
mm_positions
)
modality_idxs
=
argsort_mm_positions
(
mm_positions
)
assert
modality_idxs
==
expected_modality_idxs
assert
modality_idxs
==
expected_modality_idxs
class
SimpleLinearModel
(
torch
.
nn
.
Module
):
"""A simple linear vision model for testing."""
def
__init__
(
self
,
input_dim
:
int
=
3
*
224
*
224
,
output_dim
:
int
=
32
):
super
().
__init__
()
self
.
flatten
=
torch
.
nn
.
Flatten
()
self
.
linear
=
torch
.
nn
.
Linear
(
input_dim
,
output_dim
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
# Flatten the input and apply linear transformation
x
=
self
.
flatten
(
x
)
return
self
.
linear
(
x
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
# Single image
4
,
# Small batch
5
,
# Odd batch size (for testing padding)
],
)
def
test_run_dp_sharded_vision_model
(
batch_size
:
int
):
world_size
=
2
# Launch processes
mp
.
spawn
(
run_dp_sharded_vision_model_vs_direct
,
args
=
(
world_size
,
batch_size
,
get_open_port
(),
),
nprocs
=
world_size
,
)
def
run_dp_sharded_vision_model_vs_direct
(
local_rank
:
int
,
world_size
:
int
,
batch_size
:
int
,
master_port
:
int
):
"""
Test that run_dp_sharded_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
# initialize distributed
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create a test input tensor
image_input
=
torch
.
randn
(
batch_size
,
3
,
224
,
224
)
# Create a simple linear model
vision_model
=
SimpleLinearModel
()
# Run the model directly on the full input
with
torch
.
inference_mode
():
direct_output
=
vision_model
(
image_input
)
# Run the model through the sharded function
with
torch
.
inference_mode
():
sharded_output
=
run_dp_sharded_vision_model
(
image_input
,
vision_model
)
# Check that the world size is set up correctly
assert
get_tensor_model_parallel_world_size
()
==
world_size
# Check that the outputs have the same shape
assert
direct_output
.
shape
==
sharded_output
.
shape
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
"sizes,num_gpus,expected_shuffle_indices,expected_gpu_sample_counts,"
"expected_grouped_sizes_per_gpu,test_description"
,
[
# Empty input
([],
2
,
[],
[
0
,
0
],
[
0
,
0
],
"empty input"
),
# Fewer samples than GPUs
([
100
,
200
],
4
,
[
1
,
0
],
[
1
,
1
,
0
,
0
],
[
200
,
100
,
0
,
0
],
"fewer samples than GPUs"
),
# Single GPU
([
100
,
200
,
300
],
1
,
[
2
,
1
,
0
],
[
3
],
[
600
],
"single GPU"
),
# Balanced assignment
([
100
,
100
,
100
,
100
],
2
,
[
0
,
2
,
1
,
3
],
[
2
,
2
],
[
200
,
200
],
"balanced assignment"
),
# Unbalanced sizes - this one is trickier since the algorithm is greedy
([
1000
,
100
,
200
,
50
],
2
,
[
0
,
2
,
1
,
3
],
[
1
,
3
],
[
1000
,
350
],
"unbalanced sizes"
),
],
)
def
test_get_load_balance_assignment_cases
(
sizes
,
num_gpus
,
expected_shuffle_indices
,
expected_gpu_sample_counts
,
expected_grouped_sizes_per_gpu
,
test_description
):
"""Test get_load_balance_assignment with various input cases."""
result
=
get_load_balance_assignment
(
sizes
,
num_gpus
=
num_gpus
)
(
shuffle_indices
,
gpu_sample_counts
,
grouped_sizes_per_gpu
)
=
result
# Common assertions for all cases
assert
len
(
shuffle_indices
)
==
len
(
sizes
)
assert
len
(
gpu_sample_counts
)
==
num_gpus
assert
len
(
grouped_sizes_per_gpu
)
==
num_gpus
assert
sum
(
gpu_sample_counts
)
==
len
(
sizes
)
assert
shuffle_indices
==
expected_shuffle_indices
assert
gpu_sample_counts
==
expected_gpu_sample_counts
assert
grouped_sizes_per_gpu
==
expected_grouped_sizes_per_gpu
class
SimpleMRopeVisionModel
(
torch
.
nn
.
Module
):
"""A simple vision model for testing mrope functionality."""
def
__init__
(
self
,
spatial_merge_size
:
int
=
2
,
out_hidden_size
:
int
=
64
):
super
().
__init__
()
self
.
spatial_merge_size
=
spatial_merge_size
self
.
out_hidden_size
=
out_hidden_size
self
.
linear
=
torch
.
nn
.
Linear
(
768
,
out_hidden_size
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
grid_thw_list
:
list
[
list
[
int
]]):
"""Simple forward pass that simulates spatial merging."""
# Apply linear transformation
embeddings
=
self
.
linear
(
pixel_values
)
# Simulate spatial merging by reducing the number of patches
merge_factor
=
self
.
spatial_merge_size
*
self
.
spatial_merge_size
# Group patches and merge spatially
merged_embeddings
=
[]
start_idx
=
0
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
end_idx
=
start_idx
+
num_patches
# Get patches for this image
image_patches
=
embeddings
[
start_idx
:
end_idx
]
# Simulate spatial merging by averaging groups of patches
merged_patches
=
num_patches
//
merge_factor
if
merged_patches
>
0
:
# Reshape and average to simulate merging
reshaped
=
image_patches
[:
merged_patches
*
merge_factor
].
view
(
merged_patches
,
merge_factor
,
-
1
)
merged
=
reshaped
.
mean
(
dim
=
1
)
merged_embeddings
.
append
(
merged
)
start_idx
=
end_idx
if
merged_embeddings
:
return
torch
.
cat
(
merged_embeddings
,
dim
=
0
)
else
:
return
torch
.
empty
((
0
,
self
.
out_hidden_size
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
# Single image
3
,
# Small batch
5
,
# Odd batch size (for testing padding)
],
)
def
test_run_dp_sharded_mrope_vision_model
(
batch_size
:
int
):
world_size
=
2
# Launch processes
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_vs_direct
,
args
=
(
world_size
,
batch_size
,
get_open_port
(),
),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_vs_direct
(
local_rank
:
int
,
world_size
:
int
,
batch_size
:
int
,
master_port
:
int
):
"""
Test that run_dp_sharded_mrope_vision_model produces the same results as
calling the model directly.
"""
# Set random seed for reproducibility
current_platform
.
seed_everything
(
0
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
# initialize distributed
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create test data
grid_thw_list
=
[]
pixel_values_list
=
[]
for
i
in
range
(
batch_size
):
# Varying image sizes for better testing
t
,
h
,
w
=
1
,
4
+
i
,
4
+
i
grid_thw_list
.
append
([
t
,
h
,
w
])
num_patches
=
t
*
h
*
w
# Create random pixel values for this image
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
# Concatenate all pixel values
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
# Create a simple mrope vision model
vision_model
=
SimpleMRopeVisionModel
()
# Run the model directly on the full input (only on rank 0)
if
local_rank
==
0
:
with
torch
.
inference_mode
():
direct_output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Run the model through the sharded function
with
torch
.
inference_mode
():
sharded_output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
sharded_output
=
torch
.
cat
(
sharded_output
,
dim
=
0
)
# Check that the world size is set up correctly
assert
get_tensor_model_parallel_world_size
()
==
world_size
# Compare outputs (only on rank 0)
if
local_rank
==
0
:
# Check that the outputs have the same shape
assert
direct_output
.
shape
==
sharded_output
.
shape
# Check that the outputs are close (they should be identical)
assert
torch
.
allclose
(
direct_output
,
sharded_output
,
rtol
=
1e-5
,
atol
=
1e-5
)
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_run_dp_sharded_mrope_vision_model_empty_input
():
world_size
=
2
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_empty_input_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_empty_input_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with empty input."""
# Set up distributed environment
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create empty inputs
pixel_values
=
torch
.
empty
((
0
,
768
))
grid_thw_list
:
list
[
list
[
int
]]
=
[]
vision_model
=
SimpleMRopeVisionModel
()
# Should handle empty input gracefully
with
torch
.
inference_mode
():
output
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
assert
len
(
output
)
==
0
@
multi_gpu_test
(
num_gpus
=
4
)
def
test_run_dp_sharded_mrope_vision_model_uneven_load
():
world_size
=
4
mp
.
spawn
(
run_dp_sharded_mrope_vision_model_uneven_load_worker
,
args
=
(
world_size
,
get_open_port
()),
nprocs
=
world_size
,
)
def
run_dp_sharded_mrope_vision_model_uneven_load_worker
(
local_rank
:
int
,
world_size
:
int
,
master_port
:
int
):
"""Test run_dp_sharded_mrope_vision_model with uneven load distribution."""
# Set up distributed environment
current_platform
.
seed_everything
(
123
)
device
=
f
"
{
current_platform
.
device_name
}
:
{
local_rank
}
"
current_platform
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
update_environment_variables
({
'RANK'
:
str
(
local_rank
),
'LOCAL_RANK'
:
str
(
local_rank
),
'WORLD_SIZE'
:
str
(
world_size
),
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
str
(
master_port
),
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
world_size
)
# Create images with very different sizes
grid_thw_list
=
[
[
1
,
2
,
2
],
# Small: 4 patches
[
1
,
8
,
8
],
# Large: 64 patches
[
1
,
3
,
3
],
# Medium: 9 patches
]
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
()
# Should handle uneven distribution without errors
with
torch
.
inference_mode
():
output_tuple
=
run_dp_sharded_mrope_vision_model
(
vision_model
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
# Verify output shape is reasonable
merge_factor
=
vision_model
.
spatial_merge_size
**
2
expected_output_patches
=
list
(
math
.
prod
(
grid_thw
)
//
merge_factor
for
grid_thw
in
grid_thw_list
)
for
i
,
output
in
enumerate
(
output_tuple
):
assert
output
.
shape
[
0
]
==
expected_output_patches
[
i
]
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
@
pytest
.
mark
.
parametrize
(
"spatial_merge_size"
,
[
2
,
4
])
def
test_simple_mrope_vision_model_spatial_merge
(
spatial_merge_size
:
int
):
"""Test SimpleMRopeVisionModel with different spatial merge sizes."""
device
=
current_platform
.
device_type
grid_thw_list
=
[[
1
,
4
,
4
],
[
1
,
6
,
6
]]
# Two images
pixel_values_list
=
[]
for
grid_thw
in
grid_thw_list
:
num_patches
=
math
.
prod
(
grid_thw
)
image_pixels
=
torch
.
randn
(
num_patches
,
768
,
device
=
device
)
pixel_values_list
.
append
(
image_pixels
)
pixel_values
=
torch
.
cat
(
pixel_values_list
,
dim
=
0
)
vision_model
=
SimpleMRopeVisionModel
(
spatial_merge_size
=
spatial_merge_size
).
to
(
device
)
with
torch
.
inference_mode
():
output
=
vision_model
(
pixel_values
,
grid_thw_list
)
# Verify output dimensions based on spatial merging
total_patches
=
sum
(
math
.
prod
(
grid_thw
)
for
grid_thw
in
grid_thw_list
)
merge_factor
=
spatial_merge_size
**
2
expected_output_patches
=
total_patches
//
merge_factor
assert
output
.
shape
[
0
]
==
expected_output_patches
assert
output
.
shape
[
1
]
==
vision_model
.
out_hidden_size
vllm/model_executor/models/glm4_1v.py
View file @
babad6e5
...
@@ -69,7 +69,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -69,7 +69,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
...
@@ -83,7 +82,7 @@ from .qwen2_vl import (_create_qwen2vl_field_factory,
...
@@ -83,7 +82,7 @@ from .qwen2_vl import (_create_qwen2vl_field_factory,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
from
.vision
import
get_vit_attn_backend
,
run_dp_sharded_mrope_vision_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/idefics2_vision_model.py
View file @
babad6e5
...
@@ -34,7 +34,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -34,7 +34,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
.vision
import
run_dp_sharded_vision_model
class
Idefics2VisionEmbeddings
(
nn
.
Module
):
class
Idefics2VisionEmbeddings
(
nn
.
Module
):
...
...
vllm/model_executor/models/intern_vit.py
View file @
babad6e5
...
@@ -28,7 +28,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -28,7 +28,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
.vision
import
run_dp_sharded_vision_model
NORM2FN
=
{
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
'rms_norm'
:
RMSNorm
,
...
...
vllm/model_executor/models/kimi_vl.py
View file @
babad6e5
...
@@ -76,13 +76,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -76,13 +76,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
KimiVLConfig
,
MoonViTConfig
from
vllm.transformers_utils.configs
import
KimiVLConfig
,
MoonViTConfig
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
maybe_prefix
from
.vision
import
run_dp_sharded_mrope_vision_model
# For dummy input only
# For dummy input only
...
...
vllm/model_executor/models/mllama4.py
View file @
babad6e5
...
@@ -50,7 +50,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -50,7 +50,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
@@ -58,6 +57,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...
@@ -58,6 +57,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from
.llama4
import
Llama4ForCausalLM
from
.llama4
import
Llama4ForCausalLM
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
run_dp_sharded_vision_model
class
Llama4ImagePatchInputs
(
TensorSchema
):
class
Llama4ImagePatchInputs
(
TensorSchema
):
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
babad6e5
...
@@ -59,7 +59,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -59,7 +59,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
...
@@ -74,7 +73,7 @@ from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
...
@@ -74,7 +73,7 @@ from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
cast_overflow_tensors
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
cast_overflow_tensors
,
init_vllm_registered_model
,
maybe_prefix
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
from
.vision
import
get_vit_attn_backend
,
run_dp_sharded_mrope_vision_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
babad6e5
...
@@ -66,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -66,7 +66,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
...
@@ -78,7 +77,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
...
@@ -78,7 +77,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
from
.vision
import
get_vit_attn_backend
,
run_dp_sharded_mrope_vision_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
babad6e5
...
@@ -83,7 +83,7 @@ from .qwen2_vl import Qwen2VLProcessingInfo
...
@@ -83,7 +83,7 @@ from .qwen2_vl import Qwen2VLProcessingInfo
from
.qwen3
import
Qwen3ForCausalLM
,
Qwen3Model
from
.qwen3
import
Qwen3ForCausalLM
,
Qwen3Model
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
WeightsMapper
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
WeightsMapper
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
from
.vision
import
get_vit_attn_backend
,
run_dp_sharded_mrope_vision_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -1214,8 +1214,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1214,8 +1214,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
else
:
else
:
pixel_values
=
image_input
[
"pixel_values"
].
type
(
self
.
visual
.
dtype
)
pixel_values
=
image_input
[
"pixel_values"
].
type
(
self
.
visual
.
dtype
)
if
self
.
use_data_parallel
:
if
self
.
use_data_parallel
:
from
vllm.multimodal.utils
import
(
run_dp_sharded_mrope_vision_model
)
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values
,
pixel_values
,
grid_thw_list
,
grid_thw_list
,
...
@@ -1245,8 +1243,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1245,8 +1243,6 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values_videos
=
video_input
[
"pixel_values_videos"
].
type
(
pixel_values_videos
=
video_input
[
"pixel_values_videos"
].
type
(
self
.
visual
.
dtype
)
self
.
visual
.
dtype
)
if
self
.
use_data_parallel
:
if
self
.
use_data_parallel
:
from
vllm.multimodal.utils
import
(
run_dp_sharded_mrope_vision_model
)
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values_videos
,
pixel_values_videos
,
grid_thw_list
,
grid_thw_list
,
...
...
vllm/model_executor/models/step3_vl.py
View file @
babad6e5
...
@@ -31,7 +31,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -31,7 +31,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
@@ -40,6 +39,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...
@@ -40,6 +39,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
run_dp_sharded_vision_model
class
Step3VLImagePixelInputs
(
TypedDict
):
class
Step3VLImagePixelInputs
(
TypedDict
):
...
...
vllm/model_executor/models/vision.py
View file @
babad6e5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
math
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
from
typing
import
Final
,
Generic
,
Literal
,
Optional
,
Protocol
,
TypeVar
,
Union
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
...
@@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs(
...
@@ -123,3 +128,277 @@ def resolve_visual_encoder_outputs(
if
post_layer_norm
is
not
None
and
uses_last_layer
:
if
post_layer_norm
is
not
None
and
uses_last_layer
:
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
def
run_dp_sharded_vision_model
(
image_input
:
torch
.
Tensor
,
vision_model
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
"""Run a vision model with data parallelism (DP) sharding. The function
will shard the input image tensor on the first dimension and run the vision
model
Args:
image_input (torch.Tensor): Image input tensor.
vision_model (torch.nn.Module): Vision model.
Returns:
torch.Tensor: Output image embeddings
"""
num_chunks
=
image_input
.
shape
[
0
]
mp_world_size
=
get_tensor_model_parallel_world_size
()
num_chunks_per_rank
=
(
num_chunks
+
mp_world_size
-
1
)
//
mp_world_size
num_padded_chunks
=
num_chunks_per_rank
*
mp_world_size
-
num_chunks
pad
=
(
0
,
)
*
(
2
*
(
image_input
.
dim
()
-
1
))
+
(
0
,
num_padded_chunks
)
image_input_padded
=
torch
.
nn
.
functional
.
pad
(
image_input
,
pad
)
rank
=
get_tensor_model_parallel_rank
()
image_input_per_rank
=
image_input_padded
[
rank
*
num_chunks_per_rank
:(
rank
+
1
)
*
num_chunks_per_rank
,
...]
vision_embeddings
=
vision_model
(
image_input_per_rank
)
# Ensure tensor is contiguous before all_gather
vision_embeddings
=
vision_embeddings
.
contiguous
()
vision_embeddings
=
tensor_model_parallel_all_gather
(
vision_embeddings
,
dim
=
0
)
vision_embeddings
=
vision_embeddings
[:
num_chunks
,
...]
return
vision_embeddings
def
get_load_balance_assignment
(
sizes
:
list
[
int
],
num_gpus
:
int
=
2
,
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
]]:
"""
Generate load balancing assignment and metadata
for distributing data across GPUs.
The load is determined by the total image sizes,
not the number of images.
Args:
sizes: The size of each image
num_gpus: Number of GPUs to balance across
Returns:
shuffle_indices:
Indices to reorder data for balanced loading
gpu_sample_counts:
Number of samples assigned to each GPU
grouped_sizes_per_gpu:
Total size assigned to each GPU
Example:
```
sizes = [1000, 100, 200, 50]
num_gpus=2
```
"""
n_samples
=
len
(
sizes
)
# Handle edge cases
if
n_samples
==
0
:
return
[],
[
0
]
*
num_gpus
,
[
0
]
*
num_gpus
# Use greedy algorithm - balance by total size, not sample count
gpu_assignments
=
[
list
[
int
]()
for
_
in
range
(
num_gpus
)]
gpu_loads
=
[
0
]
*
num_gpus
# This tracks total SIZE, not sample count
# Sort indices by size (largest first for better load balancing)
# sizes = [1000, 100, 200, 50]
# large_to_small_indices = [0, 2, 1, 3]
large_to_small_indices
=
sorted
(
range
(
n_samples
),
key
=
lambda
i
:
sizes
[
i
],
reverse
=
True
)
for
idx
in
large_to_small_indices
:
# Find GPU with minimum current load (by total size)
min_gpu
=
min
(
range
(
num_gpus
),
key
=
lambda
i
:
gpu_loads
[
i
])
gpu_assignments
[
min_gpu
].
append
(
idx
)
gpu_loads
[
min_gpu
]
+=
sizes
[
idx
]
# Create shuffle indices and counts
shuffle_indices
=
list
[
int
]()
gpu_sample_counts
=
list
[
int
]()
for
gpu_id
in
range
(
num_gpus
):
# GPU_0 = [1000] = [0]
# GPU_1 = [200, 100, 50] = [2, 1, 3]
# shuffle_indices = [0, 2, 1, 3]
shuffle_indices
.
extend
(
gpu_assignments
[
gpu_id
])
# GPU_0 = [1]
# GPU_1 = [3]
# gpu_sample_counts = [1, 3]
gpu_sample_counts
.
append
(
len
(
gpu_assignments
[
gpu_id
]))
return
(
shuffle_indices
,
gpu_sample_counts
,
gpu_loads
)
def
run_dp_sharded_mrope_vision_model
(
vision_model
:
torch
.
nn
.
Module
,
pixel_values
:
torch
.
Tensor
,
grid_thw_list
:
list
[
list
[
int
]],
*
,
rope_type
:
Literal
[
"rope_3d"
,
"rope_2d"
],
)
->
tuple
[
torch
.
Tensor
,
...]:
"""Run a vision model with data parallelism (DP) sharding.
The function will shard the input image tensor on the
first dimension and run the vision model.
This function is used to run the vision model with mrope.
Args:
vision_model (torch.nn.Module): Vision model.
pixel_values (torch.Tensor): Image/Video input tensor.
grid_thw_list: List of grid dimensions for each image
rope_type: Type of rope used in the vision model.
Different rope types have different dimension to do ViT.
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
"rope_2d" for 2D rope (e.g., Kimi-VL)
Returns:
torch.Tensor: Output image embeddings
Example:
```
vision_model.out_hidden_size = 64
vision_model.spatial_merge_size = 2
pixel_values.shape = (1350, channel)
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
tp_size=2
```
"""
tp_size
=
get_tensor_model_parallel_world_size
()
# GPU_0 tp_rank_local = 0
# GPU_1 tp_rank_local = 1
tp_rank_local
=
get_tensor_model_parallel_rank
()
# patches_per_image = [1000, 100, 200, 50]
patches_per_image
=
[
math
.
prod
(
grid_thw
)
for
grid_thw
in
grid_thw_list
]
# patches_per_image = [0, 1000, 1100, 1300, 1350]
cum_patches_per_image
=
[
0
,
*
itertools
.
accumulate
(
patches_per_image
)]
# Get load balancing assignment with all metadata
# image_to_tp_rank = [0, 2, 1, 3]
# gpu_sample_counts = [1, 3]
# grouped_pixel_values_len = [1000, 350]
(
image_to_tp_rank
,
gpu_sample_counts
,
grouped_pixel_values_len
)
=
get_load_balance_assignment
(
patches_per_image
,
tp_size
)
# cu_gpu_sample_counts = [0, 1, 4]
cum_gpu_sample_counts
=
[
0
,
*
itertools
.
accumulate
(
gpu_sample_counts
)]
# GPU_0 image_idxs_local = [0]
# GPU_1 image_idxs_local = [2, 1, 3]
image_idxs_local
=
image_to_tp_rank
[
cum_gpu_sample_counts
[
tp_rank_local
]:
cum_gpu_sample_counts
[
tp_rank_local
+
1
]]
# Get the pixel values for the local images based on the image_idxs_local
if
len
(
image_idxs_local
)
>
0
:
pixel_values_local
=
torch
.
cat
([
pixel_values
[
cum_patches_per_image
[
i
]:
cum_patches_per_image
[
i
+
1
]]
for
i
in
image_idxs_local
])
else
:
# Handle case where this rank has no images
pixel_values_local
=
torch
.
empty
((
0
,
pixel_values
.
shape
[
1
]),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
# embed_dim_reduction_factor = 2 * 2
if
rope_type
==
"rope_2d"
:
embed_dim_reduction_factor
=
(
vision_model
.
merge_kernel_size
[
0
]
*
vision_model
.
merge_kernel_size
[
1
])
else
:
embed_dim_reduction_factor
=
(
vision_model
.
spatial_merge_size
*
vision_model
.
spatial_merge_size
)
# Find the max length across all ranks
# The output embedding of every DP rank has to be
# padded to this length for tensor_model_parallel_all_gather
# to work
max_len_per_rank
=
max
(
grouped_pixel_values_len
)
//
embed_dim_reduction_factor
local_grid_thw_list
=
[
grid_thw_list
[
i
]
for
i
in
image_idxs_local
]
# Run the vision model on the local pixel_values_local
if
rope_type
==
"rope_2d"
:
if
pixel_values_local
.
shape
[
0
]
>
0
:
image_embeds_local
=
vision_model
(
pixel_values_local
,
torch
.
tensor
(
local_grid_thw_list
))
if
isinstance
(
image_embeds_local
,
list
):
image_embeds_local
=
torch
.
cat
(
image_embeds_local
,
dim
=
0
)
else
:
out_dim
=
getattr
(
vision_model
.
config
,
"hidden_size"
,
None
)
image_embeds_local
=
torch
.
empty
(
(
0
,
embed_dim_reduction_factor
,
out_dim
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
else
:
if
pixel_values_local
.
shape
[
0
]
>
0
:
image_embeds_local
=
vision_model
(
pixel_values_local
,
local_grid_thw_list
)
else
:
# Handle empty case
image_embeds_local
=
torch
.
empty
((
0
,
vision_model
.
out_hidden_size
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
# Pad the output based on max_len_per_rank
# for tensor_model_parallel_all_gather to work
current_len
=
image_embeds_local
.
shape
[
0
]
if
current_len
<
max_len_per_rank
:
padding_size
=
max_len_per_rank
-
current_len
if
rope_type
==
"rope_2d"
:
padding
=
torch
.
empty
((
padding_size
,
image_embeds_local
.
shape
[
1
],
image_embeds_local
.
shape
[
2
]),
dtype
=
image_embeds_local
.
dtype
,
device
=
image_embeds_local
.
device
)
else
:
padding
=
torch
.
empty
((
padding_size
,
image_embeds_local
.
shape
[
1
]),
dtype
=
image_embeds_local
.
dtype
,
device
=
image_embeds_local
.
device
)
image_embeds_local_padded
=
torch
.
cat
([
image_embeds_local
,
padding
],
dim
=
0
)
else
:
image_embeds_local_padded
=
image_embeds_local
# Do all_gather to collect embeddings from all ranks
gathered_embeds
=
tensor_model_parallel_all_gather
(
image_embeds_local_padded
,
dim
=
0
)
# Remove padding and reconstruct per-rank embeddings
rank_embeddings
=
list
[
torch
.
Tensor
]()
for
rank
in
range
(
tp_size
):
start_idx
=
rank
*
max_len_per_rank
end_idx
=
start_idx
+
(
grouped_pixel_values_len
[
rank
]
//
embed_dim_reduction_factor
)
rank_embeddings
.
append
(
gathered_embeds
[
start_idx
:
end_idx
])
patches_per_output_image
=
[(
patch_size
//
embed_dim_reduction_factor
)
for
patch_size
in
patches_per_image
]
# Reconstruct embeddings in the original order
original_order_embeddings
=
[
None
]
*
len
(
grid_thw_list
)
current_idx
=
0
for
rank
in
range
(
tp_size
):
count
=
gpu_sample_counts
[
rank
]
if
count
>
0
:
# Get images assigned to this rank in shuffled order
# GPU_0 = image_idxs_local [0]
# GPU_1 = image_idxs_local [2, 1, 3]
rank_images
=
image_to_tp_rank
[
current_idx
:
current_idx
+
count
]
rank_embed
=
rank_embeddings
[
rank
]
# Split rank embeddings back to individual images
embed_start
=
0
for
img_idx
in
rank_images
:
img_patches
=
patches_per_output_image
[
img_idx
]
original_order_embeddings
[
img_idx
]
=
rank_embed
[
embed_start
:
embed_start
+
img_patches
]
embed_start
+=
img_patches
current_idx
+=
count
out_embeddings
=
tuple
(
embed
for
embed
in
original_order_embeddings
if
embed
is
not
None
)
assert
len
(
out_embeddings
)
==
len
(
original_order_embeddings
),
"Found unassigned embeddings"
return
out_embeddings
vllm/multimodal/utils.py
View file @
babad6e5
...
@@ -3,13 +3,11 @@
...
@@ -3,13 +3,11 @@
import
asyncio
import
asyncio
import
atexit
import
atexit
import
itertools
import
math
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
itertools
import
groupby
from
itertools
import
groupby
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
Union
from
urllib.parse
import
ParseResult
,
urlparse
from
urllib.parse
import
ParseResult
,
urlparse
from
urllib.request
import
url2pathname
from
urllib.request
import
url2pathname
...
@@ -21,9 +19,6 @@ from typing_extensions import deprecated
...
@@ -21,9 +19,6 @@ from typing_extensions import deprecated
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.connections
import
HTTPConnection
,
global_http_connection
from
vllm.connections
import
HTTPConnection
,
global_http_connection
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
)
from
.audio
import
AudioMediaIO
from
.audio
import
AudioMediaIO
from
.base
import
MediaIO
from
.base
import
MediaIO
...
@@ -33,12 +28,10 @@ from .video import VideoMediaIO
...
@@ -33,12 +28,10 @@ from .video import VideoMediaIO
_M
=
TypeVar
(
"_M"
)
_M
=
TypeVar
(
"_M"
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
.inputs
import
(
BatchedTensorInputs
,
MultiModalKwargs
,
from
.inputs
import
(
BatchedTensorInputs
,
MultiModalKwargsItem
,
MultiModalKwargsItem
,
MultiModalKwargsItems
,
MultiModalKwargsItems
,
MultiModalPlaceholderDict
)
MultiModalPlaceholderDict
)
else
:
else
:
BatchedTensorInputs
=
Any
BatchedTensorInputs
=
Any
MultiModalKwargs
=
Any
MultiModalKwargsItem
=
Any
MultiModalKwargsItem
=
Any
MultiModalKwargsItems
=
Any
MultiModalKwargsItems
=
Any
MultiModalPlaceholderDict
=
Any
MultiModalPlaceholderDict
=
Any
...
@@ -93,7 +86,7 @@ class MediaConnector:
...
@@ -93,7 +86,7 @@ class MediaConnector:
self
,
self
,
url_spec
:
ParseResult
,
url_spec
:
ParseResult
,
media_io
:
MediaIO
[
_M
],
media_io
:
MediaIO
[
_M
],
)
->
_M
:
)
->
_M
:
# type: ignore[type-var]
data_spec
,
data
=
url_spec
.
path
.
split
(
","
,
1
)
data_spec
,
data
=
url_spec
.
path
.
split
(
","
,
1
)
media_type
,
data_type
=
data_spec
.
split
(
";"
,
1
)
media_type
,
data_type
=
data_spec
.
split
(
";"
,
1
)
...
@@ -107,7 +100,7 @@ class MediaConnector:
...
@@ -107,7 +100,7 @@ class MediaConnector:
self
,
self
,
url_spec
:
ParseResult
,
url_spec
:
ParseResult
,
media_io
:
MediaIO
[
_M
],
media_io
:
MediaIO
[
_M
],
)
->
_M
:
)
->
_M
:
# type: ignore[type-var]
allowed_local_media_path
=
self
.
allowed_local_media_path
allowed_local_media_path
=
self
.
allowed_local_media_path
if
allowed_local_media_path
is
None
:
if
allowed_local_media_path
is
None
:
raise
RuntimeError
(
"Cannot load local files without "
raise
RuntimeError
(
"Cannot load local files without "
...
@@ -127,7 +120,7 @@ class MediaConnector:
...
@@ -127,7 +120,7 @@ class MediaConnector:
media_io
:
MediaIO
[
_M
],
media_io
:
MediaIO
[
_M
],
*
,
*
,
fetch_timeout
:
Optional
[
int
]
=
None
,
fetch_timeout
:
Optional
[
int
]
=
None
,
)
->
_M
:
)
->
_M
:
# type: ignore[type-var]
url_spec
=
urlparse
(
url
)
url_spec
=
urlparse
(
url
)
if
url_spec
.
scheme
.
startswith
(
"http"
):
if
url_spec
.
scheme
.
startswith
(
"http"
):
...
@@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality(
...
@@ -434,280 +427,6 @@ def group_mm_kwargs_by_modality(
yield
modality
,
len
(
items_lst
),
mm_kwargs_group
yield
modality
,
len
(
items_lst
),
mm_kwargs_group
def
run_dp_sharded_vision_model
(
image_input
:
torch
.
Tensor
,
vision_model
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
"""Run a vision model with data parallelism (DP) sharding. The function
will shard the input image tensor on the first dimension and run the vision
model
Args:
image_input (torch.Tensor): Image input tensor.
vision_model (torch.nn.Module): Vision model.
Returns:
torch.Tensor: Output image embeddings
"""
num_chunks
=
image_input
.
shape
[
0
]
mp_world_size
=
get_tensor_model_parallel_world_size
()
num_chunks_per_rank
=
(
num_chunks
+
mp_world_size
-
1
)
//
mp_world_size
num_padded_chunks
=
num_chunks_per_rank
*
mp_world_size
-
num_chunks
pad
=
(
0
,
)
*
(
2
*
(
image_input
.
dim
()
-
1
))
+
(
0
,
num_padded_chunks
)
image_input_padded
=
torch
.
nn
.
functional
.
pad
(
image_input
,
pad
)
rank
=
get_tensor_model_parallel_rank
()
image_input_per_rank
=
image_input_padded
[
rank
*
num_chunks_per_rank
:(
rank
+
1
)
*
num_chunks_per_rank
,
...]
vision_embeddings
=
vision_model
(
image_input_per_rank
)
# Ensure tensor is contiguous before all_gather
vision_embeddings
=
vision_embeddings
.
contiguous
()
vision_embeddings
=
tensor_model_parallel_all_gather
(
vision_embeddings
,
dim
=
0
)
vision_embeddings
=
vision_embeddings
[:
num_chunks
,
...]
return
vision_embeddings
def
get_load_balance_assignment
(
sizes
:
list
[
int
],
num_gpus
:
int
=
2
,
)
->
tuple
[
list
[
int
],
list
[
int
],
list
[
int
]]:
"""
Generate load balancing assignment and metadata
for distributing data across GPUs.
The load is determined by the total image sizes,
not the number of images.
Args:
sizes: The size of each image
num_gpus: Number of GPUs to balance across
Returns:
shuffle_indices:
Indices to reorder data for balanced loading
gpu_sample_counts:
Number of samples assigned to each GPU
grouped_sizes_per_gpu:
Total size assigned to each GPU
Example:
```
sizes = [1000, 100, 200, 50]
num_gpus=2
```
"""
n_samples
=
len
(
sizes
)
# Handle edge cases
if
n_samples
==
0
:
return
[],
[
0
]
*
num_gpus
,
[
0
]
*
num_gpus
# Use greedy algorithm - balance by total size, not sample count
gpu_assignments
=
[
list
[
int
]()
for
_
in
range
(
num_gpus
)]
gpu_loads
=
[
0
]
*
num_gpus
# This tracks total SIZE, not sample count
# Sort indices by size (largest first for better load balancing)
# sizes = [1000, 100, 200, 50]
# large_to_small_indices = [0, 2, 1, 3]
large_to_small_indices
=
sorted
(
range
(
n_samples
),
key
=
lambda
i
:
sizes
[
i
],
reverse
=
True
)
for
idx
in
large_to_small_indices
:
# Find GPU with minimum current load (by total size)
min_gpu
=
min
(
range
(
num_gpus
),
key
=
lambda
i
:
gpu_loads
[
i
])
gpu_assignments
[
min_gpu
].
append
(
idx
)
gpu_loads
[
min_gpu
]
+=
sizes
[
idx
]
# Create shuffle indices and counts
shuffle_indices
=
list
[
int
]()
gpu_sample_counts
=
list
[
int
]()
for
gpu_id
in
range
(
num_gpus
):
# GPU_0 = [1000] = [0]
# GPU_1 = [200, 100, 50] = [2, 1, 3]
# shuffle_indices = [0, 2, 1, 3]
shuffle_indices
.
extend
(
gpu_assignments
[
gpu_id
])
# GPU_0 = [1]
# GPU_1 = [3]
# gpu_sample_counts = [1, 3]
gpu_sample_counts
.
append
(
len
(
gpu_assignments
[
gpu_id
]))
return
(
shuffle_indices
,
gpu_sample_counts
,
gpu_loads
)
def
run_dp_sharded_mrope_vision_model
(
vision_model
:
torch
.
nn
.
Module
,
pixel_values
:
torch
.
Tensor
,
grid_thw_list
:
list
[
list
[
int
]],
*
,
rope_type
:
Literal
[
"rope_3d"
,
"rope_2d"
],
)
->
tuple
[
torch
.
Tensor
,
...]:
"""Run a vision model with data parallelism (DP) sharding.
The function will shard the input image tensor on the
first dimension and run the vision model.
This function is used to run the vision model with mrope.
Args:
vision_model (torch.nn.Module): Vision model.
pixel_values (torch.Tensor): Image/Video input tensor.
grid_thw_list: List of grid dimensions for each image
rope_type: Type of rope used in the vision model.
Different rope types have different dimension to do ViT.
"rope_3d" for 3D rope (e.g., Qwen2.5-VL)
"rope_2d" for 2D rope (e.g., Kimi-VL)
Returns:
torch.Tensor: Output image embeddings
Example:
```
vision_model.out_hidden_size = 64
vision_model.spatial_merge_size = 2
pixel_values.shape = (1350, channel)
grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]]
tp_size=2
```
"""
tp_size
=
get_tensor_model_parallel_world_size
()
# GPU_0 tp_rank_local = 0
# GPU_1 tp_rank_local = 1
tp_rank_local
=
get_tensor_model_parallel_rank
()
# patches_per_image = [1000, 100, 200, 50]
patches_per_image
=
[
math
.
prod
(
grid_thw
)
for
grid_thw
in
grid_thw_list
]
# patches_per_image = [0, 1000, 1100, 1300, 1350]
cum_patches_per_image
=
[
0
,
*
itertools
.
accumulate
(
patches_per_image
)]
# Get load balancing assignment with all metadata
# image_to_tp_rank = [0, 2, 1, 3]
# gpu_sample_counts = [1, 3]
# grouped_pixel_values_len = [1000, 350]
(
image_to_tp_rank
,
gpu_sample_counts
,
grouped_pixel_values_len
)
=
get_load_balance_assignment
(
patches_per_image
,
tp_size
)
# cu_gpu_sample_counts = [0, 1, 4]
cum_gpu_sample_counts
=
[
0
,
*
itertools
.
accumulate
(
gpu_sample_counts
)]
# GPU_0 image_idxs_local = [0]
# GPU_1 image_idxs_local = [2, 1, 3]
image_idxs_local
=
image_to_tp_rank
[
cum_gpu_sample_counts
[
tp_rank_local
]:
cum_gpu_sample_counts
[
tp_rank_local
+
1
]]
# Get the pixel values for the local images based on the image_idxs_local
if
len
(
image_idxs_local
)
>
0
:
pixel_values_local
=
torch
.
cat
([
pixel_values
[
cum_patches_per_image
[
i
]:
cum_patches_per_image
[
i
+
1
]]
for
i
in
image_idxs_local
])
else
:
# Handle case where this rank has no images
pixel_values_local
=
torch
.
empty
((
0
,
pixel_values
.
shape
[
1
]),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
# embed_dim_reduction_factor = 2 * 2
if
rope_type
==
"rope_2d"
:
embed_dim_reduction_factor
=
(
vision_model
.
merge_kernel_size
[
0
]
*
vision_model
.
merge_kernel_size
[
1
])
else
:
embed_dim_reduction_factor
=
(
vision_model
.
spatial_merge_size
*
vision_model
.
spatial_merge_size
)
# Find the max length across all ranks
# The output embedding of every DP rank has to be
# padded to this length for tensor_model_parallel_all_gather
# to work
max_len_per_rank
=
max
(
grouped_pixel_values_len
)
//
embed_dim_reduction_factor
local_grid_thw_list
=
[
grid_thw_list
[
i
]
for
i
in
image_idxs_local
]
# Run the vision model on the local pixel_values_local
if
rope_type
==
"rope_2d"
:
if
pixel_values_local
.
shape
[
0
]
>
0
:
image_embeds_local
=
vision_model
(
pixel_values_local
,
torch
.
tensor
(
local_grid_thw_list
))
if
isinstance
(
image_embeds_local
,
list
):
image_embeds_local
=
torch
.
cat
(
image_embeds_local
,
dim
=
0
)
else
:
out_dim
=
getattr
(
vision_model
.
config
,
"hidden_size"
,
None
)
image_embeds_local
=
torch
.
empty
(
(
0
,
embed_dim_reduction_factor
,
out_dim
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
else
:
if
pixel_values_local
.
shape
[
0
]
>
0
:
image_embeds_local
=
vision_model
(
pixel_values_local
,
local_grid_thw_list
)
else
:
# Handle empty case
image_embeds_local
=
torch
.
empty
((
0
,
vision_model
.
out_hidden_size
),
device
=
pixel_values
.
device
,
dtype
=
pixel_values
.
dtype
)
# Pad the output based on max_len_per_rank
# for tensor_model_parallel_all_gather to work
current_len
=
image_embeds_local
.
shape
[
0
]
if
current_len
<
max_len_per_rank
:
padding_size
=
max_len_per_rank
-
current_len
if
rope_type
==
"rope_2d"
:
padding
=
torch
.
empty
((
padding_size
,
image_embeds_local
.
shape
[
1
],
image_embeds_local
.
shape
[
2
]),
dtype
=
image_embeds_local
.
dtype
,
device
=
image_embeds_local
.
device
)
else
:
padding
=
torch
.
empty
((
padding_size
,
image_embeds_local
.
shape
[
1
]),
dtype
=
image_embeds_local
.
dtype
,
device
=
image_embeds_local
.
device
)
image_embeds_local_padded
=
torch
.
cat
([
image_embeds_local
,
padding
],
dim
=
0
)
else
:
image_embeds_local_padded
=
image_embeds_local
# Do all_gather to collect embeddings from all ranks
gathered_embeds
=
tensor_model_parallel_all_gather
(
image_embeds_local_padded
,
dim
=
0
)
# Remove padding and reconstruct per-rank embeddings
rank_embeddings
=
list
[
torch
.
Tensor
]()
for
rank
in
range
(
tp_size
):
start_idx
=
rank
*
max_len_per_rank
end_idx
=
start_idx
+
(
grouped_pixel_values_len
[
rank
]
//
embed_dim_reduction_factor
)
rank_embeddings
.
append
(
gathered_embeds
[
start_idx
:
end_idx
])
patches_per_output_image
=
[(
patch_size
//
embed_dim_reduction_factor
)
for
patch_size
in
patches_per_image
]
# Reconstruct embeddings in the original order
original_order_embeddings
=
[
None
]
*
len
(
grid_thw_list
)
current_idx
=
0
for
rank
in
range
(
tp_size
):
count
=
gpu_sample_counts
[
rank
]
if
count
>
0
:
# Get images assigned to this rank in shuffled order
# GPU_0 = image_idxs_local [0]
# GPU_1 = image_idxs_local [2, 1, 3]
rank_images
=
image_to_tp_rank
[
current_idx
:
current_idx
+
count
]
rank_embed
=
rank_embeddings
[
rank
]
# Split rank embeddings back to individual images
embed_start
=
0
for
img_idx
in
rank_images
:
img_patches
=
patches_per_output_image
[
img_idx
]
original_order_embeddings
[
img_idx
]
=
rank_embed
[
embed_start
:
embed_start
+
img_patches
]
embed_start
+=
img_patches
current_idx
+=
count
out_embeddings
=
tuple
(
embed
for
embed
in
original_order_embeddings
if
embed
is
not
None
)
assert
len
(
out_embeddings
)
==
len
(
original_order_embeddings
),
"Found unassigned embeddings"
return
out_embeddings
def
fetch_audio
(
def
fetch_audio
(
audio_url
:
str
,
audio_url
:
str
,
audio_io_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
audio_io_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment