Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c8dd12e
Unverified
Commit
4c8dd12e
authored
Feb 08, 2025
by
Isotr0py
Committed by
GitHub
Feb 08, 2025
Browse files
[Misc] Add qwen2.5-vl BNB support (#12944)
parent
256a2d29
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
30 deletions
+29
-30
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+29
-30
No files found.
vllm/model_executor/models/qwen2_5_vl.py
View file @
4c8dd12e
...
@@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
...
@@ -40,7 +40,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
...
@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -207,11 +207,12 @@ class Qwen2_5_VisionAttention(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Per attention head and per partition values.
# Per attention head and per partition values.
world_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
self
.
hidden_size_per_attention_head
=
dist_utils
.
divide
(
projection_size
,
num_heads
)
projection_size
,
num_heads
)
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
self
.
num_attention_heads_per_partition
=
dist_utils
.
divide
(
num_heads
,
world
_size
)
num_heads
,
self
.
tp
_size
)
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
self
.
qkv
=
ColumnParallelLinear
(
input_size
=
embed_dim
,
output_size
=
3
*
projection_size
,
output_size
=
3
*
projection_size
,
...
@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -231,6 +232,29 @@ class Qwen2_5_VisionAttention(nn.Module):
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
)
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Split a fused QKV projection into per-head query, key and value.

    Args:
        qkv: Fused projection output, laid out as
            ``[s, b, 3 * head * head_dim]``.

    Returns:
        A ``(q, k, v)`` tuple, each viewed as ``[s, b, head, head_dim]``
        where ``head`` is ``self.num_attention_heads_per_partition`` and
        ``head_dim`` is ``self.hidden_size_per_attention_head``.
    """
    # The leading two dims are captured *before* any gather so the final
    # view always uses the local (per-rank) sequence and batch sizes.
    seq_len, batch, _ = qkv.shape
    sharded = self.tp_size > 1

    # NOTE(review): presumably the gather concatenates every rank's shard
    # along the last dim, rebuilding the full fused projection -- confirm
    # against tensor_model_parallel_all_gather.
    fused = tensor_model_parallel_all_gather(qkv) if sharded else qkv

    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
    q, k, v = fused.chunk(3, dim=2)

    if sharded:
        # Keep only this rank's slice of each of q / k / v.
        rank = self.tp_rank
        q, k, v = [
            dist_utils.split_tensor_along_last_dim(
                part, num_partitions=self.tp_size)[rank]
            for part in (q, k, v)
        ]

    # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
    heads = self.num_attention_heads_per_partition
    head_dim = self.hidden_size_per_attention_head
    q = q.view(seq_len, batch, heads, head_dim)
    k = k.view(seq_len, batch, heads, head_dim)
    v = v.view(seq_len, batch, heads, head_dim)
    return q, k, v
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -240,15 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -240,15 +264,8 @@ class Qwen2_5_VisionAttention(nn.Module):
# [s, b, c] --> [s, b, head * 3 * head_dim]
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
x
,
_
=
self
.
qkv
(
x
)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
q
,
k
,
v
=
self
.
split_qkv
(
x
)
self
.
num_attention_heads_per_partition
,
3
*
self
.
hidden_size_per_attention_head
,
)
x
=
x
.
view
(
*
new_x_shape
)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q
,
k
,
v
=
dist_utils
.
split_tensor_along_last_dim
(
x
,
3
)
batch_size
=
q
.
shape
[
1
]
batch_size
=
q
.
shape
[
1
]
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
...
@@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -665,24 +682,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
if
name
.
endswith
(
"qkv.weight"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
hidden_size
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
,
visual_embed_dim
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
visual_embed_dim
)
elif
name
.
endswith
(
"qkv.bias"
):
visual_num_heads
=
self
.
num_heads
visual_embed_dim
=
self
.
hidden_size
head_size
=
visual_embed_dim
//
visual_num_heads
loaded_weight
=
loaded_weight
.
view
(
3
,
visual_num_heads
,
head_size
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment