Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb2b5126
Unverified
Commit
bb2b5126
authored
Sep 12, 2025
by
Isotr0py
Committed by
GitHub
Sep 11, 2025
Browse files
[VLM] Migrate remain DP-supported ViT models to use `disable_tp` (#24363)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
361ae27f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
106 deletions
+54
-106
vllm/model_executor/models/idefics2_vision_model.py
vllm/model_executor/models/idefics2_vision_model.py
+20
-65
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+13
-19
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+21
-22
No files found.
vllm/model_executor/models/idefics2_vision_model.py
View file @
bb2b5126
...
...
@@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -139,29 +138,13 @@ class Idefics2VisionAttention(nn.Module):
assert
self
.
num_heads
%
tp_size
==
0
self
.
num_heads_per_partition
=
self
.
num_heads
//
tp_size
if
use_data_parallel
:
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
qkv_proj
=
ReplicatedLinear
(
self
.
embed_dim
,
3
*
self
.
q_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
out_proj
=
ReplicatedLinear
(
self
.
embed_dim
,
self
.
embed_dim
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
)
else
:
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
self
.
num_heads
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
disable_tp
=
use_data_parallel
,
)
self
.
out_proj
=
RowParallelLinear
(
self
.
embed_dim
,
...
...
@@ -169,6 +152,7 @@ class Idefics2VisionAttention(nn.Module):
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.out_proj"
,
disable_tp
=
use_data_parallel
,
)
# Use unified MultiHeadAttention with Flash Attention support
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
...
...
@@ -201,23 +185,21 @@ class Idefics2VisionMLP(nn.Module):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
cls_fc1
=
(
ReplicatedLinear
if
use_data_parallel
else
ColumnParallelLinear
)
self
.
fc1
=
cls_fc1
(
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
disable_tp
=
use_data_parallel
,
)
cls_fc2
=
(
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
)
self
.
fc2
=
cls_fc2
(
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
disable_tp
=
use_data_parallel
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -389,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
def
_consolidate_qkv_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
)
->
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]:
qkv_idx_mappings
=
{
".self_attn.q_proj"
:
0
,
".self_attn.k_proj"
:
1
,
".self_attn.v_proj"
:
2
,
}
qkv_weights
=
{}
for
name
,
loaded_weight
in
weights
:
for
weight_name
,
idx
in
qkv_idx_mappings
.
items
():
if
weight_name
not
in
name
:
continue
new_name
=
name
.
replace
(
weight_name
,
".self_attn.qkv_proj"
)
if
new_name
not
in
qkv_weights
:
qkv_weights
[
new_name
]
=
[
None
]
*
3
qkv_weights
[
new_name
][
idx
]
=
loaded_weight
break
else
:
yield
name
,
loaded_weight
for
key
,
weight
in
qkv_weights
.
items
():
qkv_weight
=
torch
.
cat
(
weight
,
dim
=
0
)
yield
key
,
qkv_weight
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -425,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module):
loaded_params
:
set
[
str
]
=
set
()
layer_count
=
len
(
self
.
encoder
.
layers
)
if
self
.
use_data_parallel
:
weights
=
self
.
_consolidate_qkv_weights
(
weights
)
for
name
,
loaded_weight
in
weights
:
# skip pooling header
if
name
.
startswith
(
"head."
):
...
...
vllm/model_executor/models/mllama4.py
View file @
bb2b5126
...
...
@@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module):
use_data_parallel
:
bool
=
False
,
):
super
().
__init__
()
cls_fc1
=
(
ReplicatedLinear
if
use_data_parallel
else
ColumnParallelLinear
)
self
.
fc1
=
cls_fc1
(
self
.
fc1
=
ColumnParallelLinear
(
input_size
=
input_size
,
output_size
=
intermediate_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
,
disable_tp
=
use_data_parallel
,
)
cls_fc2
=
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
self
.
fc2
=
cls_fc2
(
self
.
fc2
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
output_size
,
bias
=
bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
,
disable_tp
=
use_data_parallel
,
)
self
.
activation_fn
=
nn
.
GELU
()
self
.
output_activation
=
output_activation
...
...
@@ -419,20 +418,15 @@ class Llama4UnfoldConvolution(nn.Module):
kernel_size
=
(
kernel_size
,
kernel_size
)
self
.
unfold
=
torch
.
nn
.
Unfold
(
kernel_size
=
kernel_size
,
stride
=
config
.
patch_size
)
params
=
{
"input_size"
:
config
.
num_channels
*
kernel_size
[
0
]
*
kernel_size
[
1
],
"output_size"
:
config
.
hidden_size
,
"bias"
:
False
,
"quant_config"
:
quant_config
,
"prefix"
:
f
"
{
prefix
}
.linear"
,
}
if
use_data_parallel
:
cls
=
ReplicatedLinear
else
:
cls
=
ColumnParallelLinear
params
[
"gather_output"
]
=
True
self
.
linear
=
cls
(
**
params
)
self
.
linear
=
ColumnParallelLinear
(
input_size
=
config
.
num_channels
*
kernel_size
[
0
]
*
kernel_size
[
1
],
output_size
=
config
.
hidden_size
,
bias
=
False
,
gather_output
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.linear"
,
disable_tp
=
use_data_parallel
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
unfold
(
hidden_states
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
bb2b5126
...
...
@@ -49,7 +49,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -510,32 +509,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
self
.
ln_q
=
norm_layer
(
context_dim
)
cls_fc1
=
(
ReplicatedLinear
if
use_data_parallel
else
ColumnParallelLinear
)
cls_fc2
=
(
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
)
self
.
mlp
=
nn
.
ModuleList
([
cls_fc1
(
self
.
hidden_size
,
self
.
mlp
=
nn
.
Sequential
(
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.0"
),
prefix
=
f
"
{
prefix
}
.mlp.0"
,
return_bias
=
False
,
disable_tp
=
use_data_parallel
,
),
nn
.
GELU
(),
cls_fc2
(
self
.
hidden_size
,
RowParallelLinear
(
self
.
hidden_size
,
d_model
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.2"
),
])
prefix
=
f
"
{
prefix
}
.mlp.2"
,
return_bias
=
False
,
disable_tp
=
use_data_parallel
,
),
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
ln_q
(
x
)
x
=
x
.
view
(
-
1
,
self
.
hidden_size
)
mlp_fc1
,
mlp_act
,
mlp_fc2
=
self
.
mlp
x_parallel
,
_
=
mlp_fc1
(
x
)
x_parallel
=
mlp_act
(
x_parallel
)
out
,
_
=
mlp_fc2
(
x_parallel
)
out
=
self
.
mlp
(
x
)
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment