vllm · Commit a1b3de86 (unverified)

Refactor the test code for attention kernels (#13)

Authored Mar 29, 2023 by Woosuk Kwon; committed by GitHub on Mar 29, 2023
Parent: 64e0e383
Showing 1 changed file with 53 additions and 19 deletions:

tests/kernels/attention.py (+53, -19)
 import random
-from typing import Optional
+from typing import List, Optional

 from flash_attn.flash_attention import FlashAttention
 import torch
...
@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
...
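The new ref_multi_query_kv_attention delegates to ref_masked_attention, which is defined elsewhere in tests/kernels/attention.py and not shown in this hunk. For context, a minimal sketch of such a naive reference follows; the signature and [num_tokens, num_heads, head_size] layout are inferred from the call sites above, so details may differ from the actual helper.

from typing import Optional

import torch

def ref_masked_attention(
    query: torch.Tensor,       # [num_tokens, num_heads, head_size]
    key: torch.Tensor,         # [num_tokens, num_heads, head_size]
    value: torch.Tensor,       # [num_tokens, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Naive scaled dot-product attention in plain PyTorch, used only as a
    # correctness reference for the custom kernels.
    attn = torch.einsum('qhd,khd->hqk', query, key).float() * scale
    if attn_mask is not None:
        # The mask holds large negative values at disallowed positions.
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.einsum('hqk,khd->qhd', attn.to(value.dtype), value)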
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]

-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )

     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
...
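Both the removed inline reference and the new ref_multi_query_kv_attention build the causal mask the same way: torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 places a large negative value at every position above the diagonal, so once the mask is added to the pre-softmax scores each query token only attends to itself and earlier tokens. A small standalone check of that construction (seq_len chosen here just for illustration):

import torch

seq_len = 4
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
# Strictly upper-triangular entries are -1e5; the diagonal and everything
# below it stay 0, so position i can only attend to positions 0..i.
assert attn_mask[0, 1] == -1e5 and attn_mask[1, 0] == 0 and attn_mask[2, 2] == 0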
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
         for head_size in [64, 80, 96, 128]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
...
@@ -202,4 +236,4 @@ def test_attention() -> None:
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
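With this change, running the file directly still works (it needs a CUDA device and the flash_attn package, per the imports above) and now exercises test_attention with an explicit seed. In the spirit of the NOTE about occasional precision-related failures, a hypothetical variant of the entry point could sweep a few seeds instead; this is not part of the commit:

if __name__ == '__main__':
    # Hypothetical alternative entry point (not in this commit): run the
    # same tests under several seeds rather than a single fixed one.
    for seed in (0, 1, 2):
        test_attention(seed=seed)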