Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ab278ca
Unverified
Commit
0ab278ca
authored
Jun 03, 2024
by
Antoni Baum
Committed by
GitHub
Jun 03, 2024
Browse files
[Core] Remove unnecessary copies in flash attn backend (#5138)
parent
7a64d24a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
7 deletions
+8
-7
requirements-cuda.txt
requirements-cuda.txt
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+7
-6
No files found.
requirements-cuda.txt
View file @
0ab278ca
...
...
@@ -6,4 +6,4 @@ ray >= 2.9
nvidia-ml-py # for pynvml package
torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.
8.post2
# Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.
9
# Requires PyTorch 2.3.0
vllm/attention/backends/flash_attn.py
View file @
0ab278ca
...
...
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[:
num_prefill_tokens
],
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
out
=
output
[:
num_prefill_tokens
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
...
...
@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
).
squeeze
(
1
)
out
=
output
[
num_prefill_tokens
:].
unsqueeze
(
1
),
)
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment