norm / vllm · Commits · a1b3de86

Unverified commit a1b3de86, authored Mar 29, 2023 by Woosuk Kwon, committed by GitHub Mar 29, 2023

Refactor the test code for attention kernels (#13)

parent 64e0e383
Showing 1 changed file with 53 additions and 19 deletions.

tests/kernels/attention.py (+53 / -19)
 import random
-from typing import Optional
+from typing import List, Optional
 from flash_attn.flash_attention import FlashAttention
 import torch

@@ -64,6 +64,39 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


+def ref_multi_query_kv_attention(
+    cu_seq_lens: List[int],
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    head_size = query.shape[-1]
+    scale = 1.0 / (head_size ** 0.5)
+
+    num_seqs = len(cu_seq_lens) - 1
+    ref_outputs = []
+    for i in range(num_seqs):
+        start_idx = cu_seq_lens[i]
+        end_idx = cu_seq_lens[i + 1]
+        seq_len = end_idx - start_idx
+
+        # Create attention mask
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+
+        ref_output = ref_masked_attention(
+            query[start_idx:end_idx],
+            key[start_idx:end_idx],
+            value[start_idx:end_idx],
+            scale,
+            attn_mask=attn_mask,
+        )
+        ref_outputs.append(ref_output)
+    ref_output = torch.cat(ref_outputs, dim=0)
+    return ref_output
+
+
 def test_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
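The added ref_multi_query_kv_attention builds a causal mask per packed sequence (a large negative value, -1e5, above the diagonal, added to the scores before the softmax) and hands the actual computation to ref_masked_attention, which is defined earlier in this file and is not part of the diff. As a hedged aside, a minimal sketch of what such a masked reference typically computes, assuming [seq_len, num_heads, head_size] inputs (the function name, shapes, and the CPU usage example below are illustrative, not taken from this commit):

import torch

def masked_attention_sketch(
    query: torch.Tensor,      # [seq_len, num_heads, head_size]
    key: torch.Tensor,        # [seq_len, num_heads, head_size]
    value: torch.Tensor,      # [seq_len, num_heads, head_size]
    scale: float,
    attn_mask: torch.Tensor,  # [seq_len, seq_len], additive mask
) -> torch.Tensor:
    # Scaled dot-product scores per head: [num_heads, seq_len, seq_len].
    scores = scale * torch.einsum('qhd,khd->hqk', query, key)
    # The -1e5 entries above the diagonal drive those attention weights
    # to ~0 after the softmax, which is what makes the mask causal.
    scores = scores + attn_mask
    probs = torch.softmax(scores.float(), dim=-1).to(value.dtype)
    return torch.einsum('hqk,khd->qhd', probs, value)

# Usage with the same mask construction as in the diff (CPU for brevity):
seq_len, num_heads, head_size = 4, 2, 8
q = torch.randn(seq_len, num_heads, head_size)
k = torch.randn(seq_len, num_heads, head_size)
v = torch.randn(seq_len, num_heads, head_size)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
out = masked_attention_sketch(q, k, v, head_size ** -0.5, mask)
assert out.shape == (seq_len, num_heads, head_size)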
@@ -156,30 +189,29 @@ def test_multi_query_kv_attention(
         causal=True,
     )[0]
-    ref_outputs = []
-    for i, seq_len in enumerate(seq_lens):
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-        start_idx = cu_seq_lens[i]
-        end_idx = cu_seq_lens[i + 1]
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            key[start_idx:end_idx],
-            value[start_idx:end_idx],
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
+    cu_seq_lens = cu_seq_lens.cpu().tolist()
+    ref_output = ref_multi_query_kv_attention(
+        cu_seq_lens,
+        query,
+        key,
+        value,
+        dtype,
+    )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


 @torch.inference_mode()
-def test_attention() -> None:
+def test_attention(seed: int) -> None:
+    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
+    # the test fails due to the precision issue. Re-run the test if it fails.
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
         for block_size in [8, 16]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+                print(f'Testing single_query_cached_kv_attention with '
+                      f'dtype={dtype}, block_size={block_size}, '
+                      f'head_size={head_size}')
                 test_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
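Both the removed inline loop and the new ref_multi_query_kv_attention rely on the same packing convention: cu_seq_lens is a cumulative (prefix-sum) array of sequence lengths, so sequence i occupies rows cu_seq_lens[i]:cu_seq_lens[i + 1] of the packed query/key/value tensors, and the .cpu().tolist() call above simply turns the tensor form into the plain List[int] the new helper expects. A tiny, self-contained illustration of that convention (the lengths are made up for the example):

import torch

# Hypothetical example: three packed sequences of lengths 2, 3, and 1.
seq_lens = [2, 3, 1]
cu_seq_lens = [0]
for n in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + n)
# cu_seq_lens == [0, 2, 5, 6]; sequence i occupies rows
# cu_seq_lens[i]:cu_seq_lens[i + 1] of the packed tensor.
packed = torch.arange(6)
chunks = [packed[cu_seq_lens[i]:cu_seq_lens[i + 1]] for i in range(len(seq_lens))]
assert [c.tolist() for c in chunks] == [[0, 1], [2, 3, 4], [5]]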
@@ -193,6 +225,8 @@ def test_attention() -> None:
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
         for head_size in [64, 80, 96, 128]:
+            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
+                  f'head_size={head_size}')
             test_multi_query_kv_attention(
                 num_seqs=11,
                 num_heads=3,
@@ -202,4 +236,4 @@ def test_attention() -> None:
 if __name__ == '__main__':
-    test_attention()
+    test_attention(seed=0)
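The entry point now seeds both generators before running: torch.random.manual_seed covers CPU-side sampling, while torch.cuda.manual_seed covers tensors drawn directly on the GPU. A small illustration of the reproducibility this buys, using arbitrary shapes that are not taken from this file; as the NOTE in the diff points out, kernel-level nondeterminism can still make the allclose comparison flaky despite the fixed seed:

import torch

def sample(seed: int) -> torch.Tensor:
    # Seed both RNGs, mirroring the refactored test_attention(seed).
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        return torch.randn(4, 8, device='cuda')
    return torch.randn(4, 8)

# Same seed -> identical draws from the seeded generators.
assert torch.equal(sample(0), sample(0))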