Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98f47f2a
Unverified
Commit
98f47f2a
authored
Nov 28, 2024
by
Woosuk Kwon
Committed by
GitHub
Nov 28, 2024
Browse files
[V1] Optimize the CPU overheads in FlashAttention custom op (#10733)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
8c1e77fb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+9
-8
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
98f47f2a
...
...
@@ -135,6 +135,13 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
# overheads from the non-CUDA-graph regions.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
output
=
torch
.
empty_like
(
query
)
torch
.
ops
.
vllm
.
unified_v1_flash_attention
(
output
,
...
...
@@ -153,7 +160,7 @@ class FlashAttentionImpl(AttentionImpl):
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
)
return
output
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
unified_v1_flash_attention
(
...
...
@@ -184,11 +191,6 @@ def unified_v1_flash_attention(
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
# Reshape the input keys and values and store them in the cache.
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
...
...
@@ -218,8 +220,7 @@ def unified_v1_flash_attention(
block_table
=
attn_metadata
.
block_table
,
softcap
=
logits_soft_cap
,
)
attn_output
=
attn_output
.
view
(
num_actual_tokens
,
-
1
)
# TODO(woosuk): Optimize this.
# TODO(woosuk): Remove this unnecessary copy.
output
[:
num_actual_tokens
].
copy_
(
attn_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment