Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c98be0a2
Unverified
Commit
c98be0a2
authored
Sep 23, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 23, 2025
Browse files
[Model] Enable DP for ViT in Qwen2-VL (#25445)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
5774b0a1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
19 deletions
+59
-19
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+59
-19
No files found.
vllm/model_executor/models/qwen2_vl.py
View file @
c98be0a2
...
@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_mrope_vision_model
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
...
@@ -217,17 +218,20 @@ class Qwen2VisionMLP(nn.Module):
...
@@ -217,17 +218,20 @@ class Qwen2VisionMLP(nn.Module):
act_layer
:
type
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
type
[
nn
.
Module
]
=
QuickGELU
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
fc1
=
ColumnParallelLinear
(
in_features
,
self
.
fc1
=
ColumnParallelLinear
(
in_features
,
hidden_features
,
hidden_features
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
)
prefix
=
f
"
{
prefix
}
.fc1"
,
disable_tp
=
use_data_parallel
)
self
.
act
=
act_layer
()
self
.
act
=
act_layer
()
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
self
.
fc2
=
RowParallelLinear
(
hidden_features
,
in_features
,
in_features
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
)
prefix
=
f
"
{
prefix
}
.fc2"
,
disable_tp
=
use_data_parallel
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x_parallel
,
_
=
self
.
fc1
(
x
)
x_parallel
,
_
=
self
.
fc1
(
x
)
...
@@ -293,25 +297,28 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -293,25 +297,28 @@ class Qwen2VisionAttention(nn.Module):
projection_size
:
int
,
projection_size
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Per attention head and per partition values.
# Per attention head and per partition values.
world
_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp
_size
=
(
1
if
use_data_parallel
else
self
.
tp_size
=
world_size
parallel_state
.
get_tensor_model_parallel_
world_size
())
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
projection_size
,
num_heads
)
projection_size
,
num_heads
)
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
num_heads
,
world
_size
)
num_heads
,
self
.
tp
_size
)
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
output_size
=
3
*
projection_size
,
output_size
=
3
*
projection_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
)
prefix
=
f
"
{
prefix
}
.qkv"
,
disable_tp
=
use_data_parallel
)
self
.
proj
=
RowParallelLinear
(
input_size
=
projection_size
,
self
.
proj
=
RowParallelLinear
(
input_size
=
projection_size
,
output_size
=
embed_dim
,
output_size
=
embed_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
)
prefix
=
f
"
{
prefix
}
.proj"
,
disable_tp
=
use_data_parallel
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
attn_backend
=
get_vit_attn_backend
(
...
@@ -453,6 +460,7 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -453,6 +460,7 @@ class Qwen2VisionBlock(nn.Module):
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
norm_layer
is
None
:
if
norm_layer
is
None
:
...
@@ -465,12 +473,14 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -465,12 +473,14 @@ class Qwen2VisionBlock(nn.Module):
num_heads
=
num_heads
,
num_heads
=
num_heads
,
projection_size
=
dim
,
projection_size
=
dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
)
self
.
mlp
=
Qwen2VisionMLP
(
dim
,
self
.
mlp
=
Qwen2VisionMLP
(
dim
,
mlp_hidden_dim
,
mlp_hidden_dim
,
act_layer
=
act_layer
,
act_layer
=
act_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
,
use_data_parallel
=
use_data_parallel
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -531,6 +541,7 @@ class Qwen2VisionPatchMerger(nn.Module):
...
@@ -531,6 +541,7 @@ class Qwen2VisionPatchMerger(nn.Module):
spatial_merge_size
:
int
=
2
,
spatial_merge_size
:
int
=
2
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
...
@@ -542,13 +553,15 @@ class Qwen2VisionPatchMerger(nn.Module):
...
@@ -542,13 +553,15 @@ class Qwen2VisionPatchMerger(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.0"
),
prefix
=
f
"
{
prefix
}
.mlp.0"
,
disable_tp
=
use_data_parallel
),
nn
.
GELU
(),
nn
.
GELU
(),
RowParallelLinear
(
self
.
hidden_size
,
RowParallelLinear
(
self
.
hidden_size
,
d_model
,
d_model
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.2"
),
prefix
=
f
"
{
prefix
}
.mlp.2"
,
disable_tp
=
use_data_parallel
),
])
])
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -600,6 +613,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -600,6 +613,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_eps
:
float
=
1e-6
,
norm_eps
:
float
=
1e-6
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -613,6 +627,9 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -613,6 +627,9 @@ class Qwen2VisionTransformer(nn.Module):
num_heads
=
vision_config
.
num_heads
num_heads
=
vision_config
.
num_heads
mlp_ratio
=
vision_config
.
mlp_ratio
mlp_ratio
=
vision_config
.
mlp_ratio
self
.
use_data_parallel
=
use_data_parallel
self
.
out_hidden_size
=
vision_config
.
hidden_size
self
.
spatial_merge_size
=
spatial_merge_size
self
.
spatial_merge_size
=
spatial_merge_size
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
embed_dim
=
embed_dim
self
.
embed_dim
=
embed_dim
...
@@ -634,7 +651,8 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -634,7 +651,8 @@ class Qwen2VisionTransformer(nn.Module):
mlp_ratio
=
mlp_ratio
,
mlp_ratio
=
mlp_ratio
,
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
)
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
use_data_parallel
)
for
layer_idx
in
range
(
depth
)
for
layer_idx
in
range
(
depth
)
])
])
self
.
merger
=
Qwen2VisionPatchMerger
(
self
.
merger
=
Qwen2VisionPatchMerger
(
...
@@ -643,6 +661,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -643,6 +661,7 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
use_data_parallel
=
use_data_parallel
,
)
)
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
...
@@ -659,8 +678,9 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -659,8 +678,9 @@ class Qwen2VisionTransformer(nn.Module):
def
device
(
self
)
->
torch
.
device
:
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
rot_pos_emb
(
self
,
grid_thw
:
list
[
list
[
int
]]
)
->
torch
.
Tensor
:
pos_ids
=
[]
pos_ids
=
[]
max_grid_size
=
0
for
t
,
h
,
w
in
grid_thw
:
for
t
,
h
,
w
in
grid_thw
:
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
wpos_ids
=
torch
.
arange
(
w
).
unsqueeze
(
0
).
expand
(
h
,
-
1
)
wpos_ids
=
torch
.
arange
(
w
).
unsqueeze
(
0
).
expand
(
h
,
-
1
)
...
@@ -678,8 +698,8 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -678,8 +698,8 @@ class Qwen2VisionTransformer(nn.Module):
).
permute
(
0
,
2
,
1
,
3
).
flatten
()
).
permute
(
0
,
2
,
1
,
3
).
flatten
()
pos_ids
.
append
(
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
max_grid_size
=
max
(
max_grid_size
,
h
,
w
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
()
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
return
rotary_pos_emb
...
@@ -698,7 +718,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -698,7 +718,7 @@ class Qwen2VisionTransformer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
grid_thw
:
torch
.
Tensor
,
grid_thw
:
list
[
list
[
int
]]
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# patchify
# patchify
x
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
x
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -708,8 +728,9 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -708,8 +728,9 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
# compute cu_seqlens
# compute cu_seqlens
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw_
=
torch
.
tensor
(
grid_thw
)
grid_thw
[:,
0
]).
cumsum
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw_
[:,
1
]
*
grid_thw_
[:,
2
],
grid_thw_
[:,
0
]).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dim
=
0
,
dtype
=
torch
.
int32
)
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
"constant"
,
0
)
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
"constant"
,
0
)
...
@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1112,6 +1133,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"model."
:
"language_model.model."
,
"model."
:
"language_model.model."
,
})
})
supports_encoder_tp_data
=
True
def
get_mrope_input_positions
(
def
get_mrope_input_positions
(
self
,
self
,
input_tokens
:
list
[
int
],
input_tokens
:
list
[
int
],
...
@@ -1239,6 +1262,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1239,6 +1262,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
...
@@ -1249,6 +1273,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1249,6 +1273,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
quant_config
=
self
.
_maybe_ignore_quant_config
(
quant_config
),
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
)
)
else
:
else
:
self
.
visual
=
None
self
.
visual
=
None
...
@@ -1357,7 +1382,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1357,7 +1382,15 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds
=
image_input
[
"image_embeds"
]
image_embeds
=
image_input
[
"image_embeds"
]
else
:
else
:
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw
)
if
self
.
use_data_parallel
:
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
else
:
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw_list
)
# Split concatenated embeddings for each image item.
# Split concatenated embeddings for each image item.
merge_size
=
self
.
visual
.
spatial_merge_size
merge_size
=
self
.
visual
.
spatial_merge_size
...
@@ -1377,7 +1410,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1377,7 +1410,14 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
video_embeds
=
video_input
[
"video_embeds"
]
video_embeds
=
video_input
[
"video_embeds"
]
else
:
else
:
pixel_values_videos
=
video_input
[
"pixel_values_videos"
]
pixel_values_videos
=
video_input
[
"pixel_values_videos"
]
video_embeds
=
self
.
visual
(
pixel_values_videos
,
grid_thw
=
grid_thw
)
if
self
.
use_data_parallel
:
return
run_dp_sharded_mrope_vision_model
(
self
.
visual
,
pixel_values_videos
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
else
:
video_embeds
=
self
.
visual
(
pixel_values_videos
,
grid_thw
=
grid_thw_list
)
# Split concatenated embeddings for each video item.
# Split concatenated embeddings for each video item.
merge_size
=
self
.
visual
.
spatial_merge_size
merge_size
=
self
.
visual
.
spatial_merge_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment