Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47c71262
Unverified
Commit
47c71262
authored
Mar 21, 2025
by
Isotr0py
Committed by
GitHub
Mar 21, 2025
Browse files
[Misc] Add attention mask pre-computation optimization back to Qwen2.5-VL (#15273)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
a989ca2b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
16 deletions
+35
-16
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+23
-10
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+12
-6
No files found.
vllm/model_executor/models/qwen2_5_vl.py
View file @
47c71262
...
@@ -608,6 +608,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -608,6 +608,17 @@ class Qwen2_5_VisionTransformer(nn.Module):
window_index
=
torch
.
cat
(
window_index
,
dim
=
0
)
window_index
=
torch
.
cat
(
window_index
,
dim
=
0
)
return
window_index
,
cu_window_seqlens
return
window_index
,
cu_window_seqlens
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
return
max_seqlen
,
seqlens
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -645,25 +656,27 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -645,25 +656,27 @@ class Qwen2_5_VisionTransformer(nn.Module):
# transformers
# transformers
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
max_seqlen
=
None
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
seqlens
=
None
max_seqlen_full
,
seqlens_full
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen_window
,
seqlens_window
=
self
.
compute_attn_mask_seqlen
(
cu_window_seqlens
)
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
if
layer_num
in
self
.
fullatt_block_indexes
:
if
layer_num
in
self
.
fullatt_block_indexes
:
cu_seqlens_now
=
cu_seqlens
cu_seqlens_now
=
cu_seqlens
max_seqlen_now
=
max_seqlen_full
seqlens_now
=
seqlens_full
else
:
else
:
cu_seqlens_now
=
cu_window_seqlens
cu_seqlens_now
=
cu_window_seqlens
# pre-compute cu_seqlens for window attn
max_seqlen_now
=
max_seqlen_window
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
seqlens_now
=
seqlens_window
max_seqlen
=
(
cu_seqlens_now
[
1
:]
-
cu_seqlens_now
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens_now
[
1
:]
-
cu_seqlens_now
[:
-
1
]).
tolist
()
hidden_states
=
blk
(
hidden_states
=
blk
(
hidden_states
,
hidden_states
,
cu_seqlens
=
cu_seqlens_now
,
cu_seqlens
=
cu_seqlens_now
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
_now
,
seqlens
=
seqlens
,
seqlens
=
seqlens
_now
,
)
)
# For Qwen2.5-VL-3B, float16 will overflow at last block
# For Qwen2.5-VL-3B, float16 will overflow at last block
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
47c71262
...
@@ -617,6 +617,16 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -617,6 +617,16 @@ class Qwen2VisionTransformer(nn.Module):
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
return
rotary_pos_emb
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
return
max_seqlen
,
seqlens
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -638,12 +648,8 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -638,12 +648,8 @@ class Qwen2VisionTransformer(nn.Module):
# transformers
# transformers
x
=
x
.
unsqueeze
(
1
)
x
=
x
.
unsqueeze
(
1
)
max_seqlen
=
None
# pre-compute seqlens for attn mask to reduce cuMemcpy operations
seqlens
=
None
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
for
blk
in
self
.
blocks
:
for
blk
in
self
.
blocks
:
x
=
blk
(
x
=
blk
(
x
,
x
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment