Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c713a97
Unverified
Commit
3c713a97
authored
Sep 20, 2025
by
Isotr0py
Committed by
GitHub
Sep 20, 2025
Browse files
[Model] Cleanup InternViT's data parallel implementation (#25306)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
bf8b26ca
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
121 deletions
+37
-121
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+37
-121
No files found.
vllm/model_executor/models/intern_vit.py
View file @
3c713a97
...
...
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.activation import get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -164,15 +163,6 @@ class InternParallelAttention(nn.Module):
self
.
tp_size
)
self
.
scale
=
self
.
head_dim
**-
0.5
if
use_data_parallel
:
self
.
qkv
=
ReplicatedLinear
(
self
.
embed_dim
,
3
*
self
.
head_dim
*
self
.
num_heads
,
bias
=
config
.
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
,
)
else
:
self
.
qkv
=
QKVParallelLinear
(
self
.
embed_dim
,
self
.
head_dim
,
...
...
@@ -180,6 +170,7 @@ class InternParallelAttention(nn.Module):
bias
=
config
.
qkv_bias
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv"
,
disable_tp
=
use_data_parallel
,
)
self
.
qk_normalization
=
config
.
qk_normalization
...
...
@@ -192,19 +183,12 @@ class InternParallelAttention(nn.Module):
eps
=
config
.
layer_norm_eps
,
var_hidden_size
=
self
.
embed_dim
)
if
use_data_parallel
:
self
.
proj
=
ReplicatedLinear
(
self
.
dummy_dim
,
self
.
embed_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
)
else
:
self
.
proj
=
RowParallelLinear
(
self
.
dummy_dim
,
self
.
embed_dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
disable_tp
=
use_data_parallel
,
)
self
.
attn
=
MultiHeadAttention
(
self
.
num_heads_per_partition
,
...
...
@@ -236,72 +220,6 @@ class InternParallelAttention(nn.Module):
return
out
class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Non-tensor-parallel attention built from plain ``nn.Linear`` projections.
    The QKV and output projections are sized for
    ``num_attention_heads + num_dummy_heads`` heads (``dummy_dim``), so the
    per-head reshape in ``forward`` and the attention layer use that same
    padded head count.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        """Build the projections, optional Q/K norms and attention layer.

        Args:
            config: Model config; reads ``hidden_size``,
                ``num_attention_heads``, ``qkv_bias``, ``qk_normalization``
                and ``layer_norm_eps``.
            num_dummy_heads: Extra heads appended to the real ones when
                sizing the projections.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        # Additional dummy heads are used to enable TP for common GPU counts.
        # Every per-head reshape must use this padded count, otherwise the
        # (B, N, dummy_dim) projections cannot be viewed as heads.
        self.total_num_heads = num_dummy_heads + self.num_heads
        self.dummy_dim = self.total_num_heads * self.head_dim

        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
                             3 * self.dummy_dim,
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

        # Use unified MultiHeadAttention with automatic backend selection
        self.attn = MultiHeadAttention(self.total_num_heads, self.head_dim,
                                       self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` and project back to ``embed_dim``.

        Args:
            x: 3-D tensor unpacked as ``(B, N, C)``; it is fed to the
                ``embed_dim``-input QKV projection.

        Returns:
            The output of the ``proj`` linear layer applied to the
            attention result.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Each chunk is (B, N, dummy_dim); split it into the padded head
        # count so the element count matches when dummy heads are present.
        head_shape = (B, N, self.total_num_heads, self.head_dim)
        q = q.view(head_shape)
        k = k.view(head_shape)
        v = v.view(head_shape)

        if self.qk_normalization:
            # The norms are defined over the flattened (heads * head_dim)
            # axis, so merge the head axes, normalize, then split again.
            q = self.q_norm(q.flatten(-2, -1)).view(head_shape)
            k = self.k_norm(k.flatten(-2, -1)).view(head_shape)

        # Use unified MultiHeadAttention with automatic backend selection
        x = self.attn(q, k, v)

        x = self.proj(x)
        return x
class
InternMLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -315,20 +233,18 @@ class InternMLP(nn.Module):
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
cls_fc1
=
(
ReplicatedLinear
if
use_data_parallel
else
ColumnParallelLinear
)
self
.
fc1
=
cls_fc1
(
config
.
hidden_size
,
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc1"
)
cls_fc2
=
(
ReplicatedLinear
if
use_data_parallel
else
RowParallelLinear
)
self
.
fc2
=
cls_fc2
(
config
.
intermediate_size
,
prefix
=
f
"
{
prefix
}
.fc1"
,
disable_tp
=
use_data_parallel
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.fc2"
)
prefix
=
f
"
{
prefix
}
.fc2"
,
disable_tp
=
use_data_parallel
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
...
...
@@ -385,20 +301,20 @@ class InternVisionEncoderLayer(nn.Module):
use_data_parallel
:
bool
=
False
,
):
# fallback to sdpa attention if tp unavailable
# tp_size = get_tensor_model_parallel_world_size()
tp_size
=
(
1
if
use_data_parallel
else
get_tensor_model_parallel_world_size
())
num_heads
=
config
.
num_attention_heads
if
(
num_heads
+
num_dummy_heads
)
%
tp_size
==
0
:
# if the number of heads is not divisible by tp_size,
# we also disable Attention's TP
use_data_parallel
=
(
use_data_parallel
or
(
num_heads
+
num_dummy_heads
)
%
tp_size
!=
0
)
return
InternParallelAttention
(
config
,
quant_config
=
quant_config
,
num_dummy_heads
=
num_dummy_heads
,
prefix
=
prefix
,
use_data_parallel
=
use_data_parallel
)
return
InternSdpaAttention
(
config
,
num_dummy_heads
=
num_dummy_heads
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment