sglang · Commit 05471f21 (Unverified)

Update test_flashinfer (#560)

Authored Jun 24, 2024 by Liangsheng Yin; committed by GitHub on Jun 24, 2024.
Parent: 1fa15099

Showing 1 changed file with 82 additions and 31 deletions.

test/srt/test_flashinfer.py  (+82, -31)
 import flashinfer
 import pytest
 import torch
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
 from sglang.srt.layers.token_attention import token_attention_fwd
 
+flashinfer_prefill_wrapper = None
+flashinfer_decode_wrapper = None
 
 
 @pytest.mark.parametrize("batch_size", [12, 37, 67])
 @pytest.mark.parametrize("kv_len", [54, 97])
 @pytest.mark.parametrize("qo_len", [37, 17])
 @pytest.mark.parametrize("num_kv_heads", [4])
-@pytest.mark.parametrize("num_qo_heads", [4, 32])
+@pytest.mark.parametrize("num_qo_heads", [32, 4])
 @pytest.mark.parametrize("head_dim", [128])
-@pytest.mark.parametrize("use_wrapper", [True, False])
 def test_batch_prefill_with_paged_kv_cache(
     batch_size,
     kv_len,
     ...
@@ -20,12 +26,13 @@ def test_batch_prefill_with_paged_kv_cache(
     num_kv_heads,
     num_qo_heads,
     head_dim,
-    use_wrapper,
 ):
+    init_flashinfer(num_qo_heads, num_kv_heads)
+
     q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
-    q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
+    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
     total_tokens = kv_len * batch_size
-    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
+    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
     kv_indices = torch.arange(0, total_tokens).to(0).int()
     kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
 
 ...
@@ -70,21 +77,44 @@ def test_batch_prefill_with_paged_kv_cache(
         max_len_extend,
     )
 
-    if use_wrapper:
-        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper()
-        wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
-        o = wrapper.forward(q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len)
-    else:
-        o = flashinfer.batch_prefill_with_paged_kv_cache(
-            q,
-            q_indptr,
-            kv_data,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-        )
+    o_redundant = torch.empty_like(q)
+    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
+    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
+    b_seq_len_prefix = b_seq_len - b_seq_len_extend
+    redundant_attention(
+        q,
+        k_extend,
+        v_extend,
+        o_redundant,
+        k_buffer,
+        v_buffer,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        b_seq_len_prefix,
+        max_len_in_batch,
+    )
+    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
+    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
+    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)
+
+    flashinfer_prefill_wrapper.end_forward()
+    flashinfer_prefill_wrapper.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        1,
+    )
+    o = flashinfer_prefill_wrapper.forward(
+        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
+    )
 
     print("Mean: ", torch.mean(torch.abs(o - o_triton)))
     print("Max: ", torch.max(torch.abs(o - o_triton)))
 
 ...
@@ -105,10 +135,11 @@ def test_batch_decode_with_paged_kv_cache(
 ):
     # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
     # to test different shape of decode, change the parameters in the __main__, and run decode only once
+    init_flashinfer(num_qo_heads, num_kv_heads)
 
     q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
     total_tokens = kv_len * batch_size
-    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
+    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
     kv_indices = torch.arange(0, total_tokens).to(0).int()
     kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
 
 ...
@@ -139,26 +170,46 @@ def test_batch_decode_with_paged_kv_cache(
         total_tokens,
     )
 
-    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper()
-    wrapper.begin_forward(
+    flashinfer_decode_wrapper.end_forward()
+    flashinfer_decode_wrapper.begin_forward(
         kv_indptr,
         kv_indices,
         kv_last_page_len,
-        batch_size,
         num_qo_heads,
         num_kv_heads,
         head_dim,
         1,
-        "NONE",
-        "float16",
+        pos_encoding_mode="NONE",
+        data_type="float16",
     )
-    o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len)
+    o = flashinfer_decode_wrapper.forward(
+        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
+    )
 
     print("Mean: ", torch.mean(torch.abs(o - o_triton)))
     print("Max: ", torch.max(torch.abs(o - o_triton)))
     assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
+
+
+def init_flashinfer(num_attention_heads, num_kv_heads):
+    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
+        use_tensor_cores = True
+    else:
+        use_tensor_cores = False
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
+
+    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper
+
+    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
+    )
 
 if __name__ == "__main__":
-    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128, False)
-    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128, True)
+    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
+    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
     test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
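
A note on the indexing scheme both tests construct: with a page size of 1, every token occupies its own page, so kv_indptr holds the prefix sums of per-request KV lengths, kv_indices lists each request's slots in the flat token pool, and kv_last_page_len is all ones. The plain-PyTorch sketch below is a hypothetical helper, not part of this commit; it simply gathers one request's K and V from a kv_data tensor shaped like the one in the updated test, which may help when reading the begin_forward arguments above.

import torch

def gather_request_kv(kv_data, kv_indptr, kv_indices, req_id):
    # kv_data:    (total_tokens, 2, num_kv_heads, head_dim), page_size = 1
    # kv_indptr:  (batch_size + 1,) prefix sums of per-request KV lengths
    # kv_indices: (total_tokens,) pool slot of every token, grouped by request
    start, end = kv_indptr[req_id].item(), kv_indptr[req_id + 1].item()
    slots = kv_indices[start:end].long()
    k = kv_data[slots, 0]  # (kv_len, num_kv_heads, head_dim)
    v = kv_data[slots, 1]
    return k, v

if __name__ == "__main__":
    batch_size, kv_len, num_kv_heads, head_dim = 12, 54, 4, 128
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim)
    kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32) * kv_len
    kv_indices = torch.arange(0, total_tokens, dtype=torch.int32)
    k, v = gather_request_kv(kv_data, kv_indptr, kv_indices, req_id=3)
    assert k.shape == (kv_len, num_kv_heads, head_dim)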