Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
70b808fe
Unverified
Commit
70b808fe
authored
Mar 11, 2025
by
yexin(叶鑫)
Committed by
GitHub
Mar 11, 2025
Browse files
[Perf]:Optimize qwen2-vl to reduce cudaMemcpyAsync (#14377)
Signed-off-by:
cynthieye
<
987073381@qq.com
>
parent
63d635d1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
24 deletions
+70
-24
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+33
-12
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+37
-12
No files found.
vllm/model_executor/models/qwen2_5_vl.py
View file @
70b808fe
...
@@ -259,6 +259,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -259,6 +259,8 @@ class Qwen2_5_VisionAttention(nn.Module):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
Optional
[
int
]
=
None
,
# Only used for Flash Attention
seqlens
:
Optional
[
list
[
int
]]
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, head * 3 * head_dim]
# [s, b, c] --> [s, b, head * 3 * head_dim]
x
,
_
=
self
.
qkv
(
x
)
x
,
_
=
self
.
qkv
(
x
)
...
@@ -285,7 +287,6 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -285,7 +287,6 @@ class Qwen2_5_VisionAttention(nn.Module):
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
output
=
flash_attn_varlen_func
(
q
,
output
=
flash_attn_varlen_func
(
q
,
k
,
k
,
v
,
v
,
...
@@ -321,7 +322,6 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -321,7 +322,6 @@ class Qwen2_5_VisionAttention(nn.Module):
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
kv_seqlen
=
None
,
kv_seqlen
=
None
,
device
=
q
.
device
)
device
=
q
.
device
)
...
@@ -364,11 +364,20 @@ class Qwen2_5_VisionBlock(nn.Module):
...
@@ -364,11 +364,20 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
def
forward
(
rotary_pos_emb
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
Optional
[
int
]
=
None
,
# Only used for Flash Attention
seqlens
:
Optional
[
list
[
int
]]
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
)
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
return
x
return
x
...
@@ -528,6 +537,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -528,6 +537,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
@@ -633,14 +643,25 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -633,14 +643,25 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# transformers
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
max_seqlen
=
None
seqlens
=
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
if
layer_num
in
self
.
fullatt_block_indexes
:
if
layer_num
in
self
.
fullatt_block_indexes
:
cu_seqlens_now
=
cu_seqlens
cu_seqlens_now
=
cu_seqlens
else
:
else
:
cu_seqlens_now
=
cu_window_seqlens
cu_seqlens_now
=
cu_window_seqlens
hidden_states
=
blk
(
hidden_states
,
hidden_states
=
blk
(
hidden_states
,
cu_seqlens
=
cu_seqlens_now
,
cu_seqlens
=
cu_seqlens_now
,
rotary_pos_emb
=
rotary_pos_emb
)
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
# For Qwen2.5-VL-3B, float16 will overflow at last block
# for long visual tokens sequences.
# for long visual tokens sequences.
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
70b808fe
...
@@ -307,6 +307,8 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -307,6 +307,8 @@ class Qwen2VisionAttention(nn.Module):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
Optional
[
int
]
=
None
,
# Only used for Flash Attention
seqlens
:
Optional
[
list
[
int
]]
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# [s, b, c] --> [s, b, 3 * head * head_dim]
# [s, b, c] --> [s, b, 3 * head * head_dim]
...
@@ -329,7 +331,6 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -329,7 +331,6 @@ class Qwen2VisionAttention(nn.Module):
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
output
=
flash_attn_varlen_func
(
q
,
output
=
flash_attn_varlen_func
(
q
,
k
,
k
,
v
,
v
,
...
@@ -365,7 +366,6 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -365,7 +366,6 @@ class Qwen2VisionAttention(nn.Module):
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
q_seqlen
=
seqlens
,
kv_seqlen
=
None
,
kv_seqlen
=
None
,
device
=
q
.
device
)
device
=
q
.
device
)
...
@@ -409,11 +409,22 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -409,11 +409,22 @@ class Qwen2VisionBlock(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
def
forward
(
rotary_pos_emb
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
Optional
[
int
]
=
None
,
# Only used for Flash Attention
seqlens
:
Optional
[
list
[
int
]]
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
)
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
return
x
return
x
...
@@ -570,6 +581,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -570,6 +581,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
@@ -624,8 +636,21 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -624,8 +636,21 @@ class Qwen2VisionTransformer(nn.Module):
# transformers
# transformers
x
=
x
.
unsqueeze
(
1
)
x
=
x
.
unsqueeze
(
1
)
max_seqlen
=
None
seqlens
=
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
for
blk
in
self
.
blocks
:
for
blk
in
self
.
blocks
:
x
=
blk
(
x
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
)
x
=
blk
(
x
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
# adapter
# adapter
x
=
self
.
merger
(
x
)
x
=
self
.
merger
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment