Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
073a4bd1
Unverified
Commit
073a4bd1
authored
Dec 01, 2024
by
Woosuk Kwon
Committed by
GitHub
Dec 01, 2024
Browse files
[Kernel] Use `out` arg in flash_attn_varlen_func (#10811)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
b7954776
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
7 deletions
+21
-7
CMakeLists.txt
CMakeLists.txt
+1
-1
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+17
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+3
-3
No files found.
CMakeLists.txt
View file @
073a4bd1
...
...
@@ -522,7 +522,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
fdf6d72b48aea41f4ae6a89139a453dae554abc8
GIT_TAG
04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
tests/kernels/test_flash_attn.py
View file @
073a4bd1
...
...
@@ -71,6 +71,7 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
...
@@ -81,6 +82,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
...
...
@@ -116,17 +118,22 @@ def test_flash_attn_with_paged_kv(
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
q
=
query
.
unsqueeze
(
1
)
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
output
=
flash_attn_with_kvcache
(
q
=
q
uery
.
unsqueeze
(
1
)
,
q
=
q
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
out
=
out
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
).
squeeze
(
1
)
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
@@ -141,7 +148,10 @@ def test_flash_attn_with_paged_kv(
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)]])
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)],
[(
1
,
523
),
(
1
,
37
),
(
1
,
2011
)]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
...
...
@@ -151,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
...
...
@@ -197,10 +208,12 @@ def test_varlen_with_paged_kv(
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
out
=
torch
.
empty_like
(
query
)
if
use_out
else
None
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
...
...
@@ -211,6 +224,7 @@ def test_varlen_with_paged_kv(
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
output
=
output
if
not
use_out
else
out
ref_output
=
ref_paged_attn
(
query
=
query
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
073a4bd1
...
...
@@ -205,10 +205,12 @@ def unified_v1_flash_attention(
v_scale
,
)
attn_output
=
flash_attn_varlen_func
(
# Compute attention and update output up to `num_actual_tokens`.
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_len
,
cu_seqlens_k
=
attn_metadata
.
seq_start_loc
,
...
...
@@ -220,8 +222,6 @@ def unified_v1_flash_attention(
block_table
=
attn_metadata
.
block_table
,
softcap
=
logits_soft_cap
,
)
# TODO(woosuk): Remove this unnecessary copy.
output
[:
num_actual_tokens
].
copy_
(
attn_output
)
def
unified_v1_flash_attention_fake
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment