change / sglang / Commits / e2102669

Unverified commit e2102669, authored Dec 17, 2024 by bjmsong; committed by GitHub on Dec 17, 2024.
benchmark decoding attention kernel with cudnn (#2467)

Co-authored-by: root <bjmsong@126.com>

Parent: bd619616
Showing 2 changed files with 405 additions and 172 deletions:

- benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py (+0, -172)
- benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py (+405, -0)
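In short: the commit deletes the two-provider SGLang[Triton]-vs-FlashInfer benchmark and adds a three-provider version that also drives cuDNN's paged SDPA graph API, moving the timing harness from `triton.testing.do_bench` to `torch.utils.benchmark.Timer`.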
Deleted file: benchmark/kernels/decoding_attention_triton/sglang_triton_vs_flashinfer.py (100644 → 0, view at bd619616)
```python
import itertools

import torch
import triton
import triton.language as tl
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)


def decode_attention_sglang(
    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, num_kv_splits
):
    k_buffer = kv_data[:, 0].view(-1, head_num_kv, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, head_num_kv, head_dim).contiguous()
    o = torch.empty_like(q)
    total_tokens = batch_size * kv_len
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
    max_len_in_batch = kv_len
    sm_scale = 1.0 / (head_dim**0.5)

    attn_logits = torch.empty(
        (batch_size, head_num_q, num_kv_splits, head_dim + 1),
        dtype=torch.float32,
        device="cuda",
    )

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )
    return o


def decode_attention_flashinfer(
    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
):
    total_tokens = batch_size * kv_len
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device="cuda")

    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        head_num_q,
        head_num_kv,
        head_dim,
        1,
        pos_encoding_mode="NONE",
        data_type=dtype,
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, head_num_q, head_dim), kv_data
    )
    return o


def calculate_diff():
    dtype = torch.bfloat16
    batch_size = 4
    kv_len = 16
    head_num_q = 32
    head_num_kv = 32
    head_dim = 128

    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
    kv_data = torch.randn(
        batch_size * kv_len, 2, head_num_kv, head_dim, dtype=dtype, device="cuda"
    )

    output_sglang = decode_attention_sglang(
        q,
        kv_data,
        batch_size,
        kv_len,
        head_num_q,
        head_num_kv,
        head_dim,
        num_kv_splits=8,
    )
    output_flashinfer = decode_attention_flashinfer(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype=dtype
    )

    print(f"SGLang output={output_sglang}")
    print(f"FlashInfer output={output_flashinfer}")
    if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
        print("✅ SGLang[Triton] and FlashInfer match")
    else:
        print("❌ SGLang[Triton] and FlashInfer differ")


head_dim = 128
dtype = torch.float16
batch_size_range = [2**i for i in range(0, 8, 2)]
kv_len_range = [2**i for i in range(6, 13, 1)]
head_num_range = [32, 64]
configs = list(itertools.product(head_num_range, batch_size_range, kv_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["head_num", "batch_size", "kv_len"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["sglang_triton", "flashinfer"],
        line_names=["SGLang[triton]", "FlashInfer"],
        styles=[("green", "-"), ("red", "-")],
        ylabel="us",
        plot_name="decode-attention-performance",
        args={},
    )
)
def benchmark(head_num, batch_size, kv_len, provider):
    head_num_q = head_num_kv = head_num
    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
    kv_data = torch.randn(
        batch_size * kv_len, 2, head_num_kv, head_dim, dtype=dtype, device="cuda"
    )
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sglang_triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: decode_attention_sglang(
                q,
                kv_data,
                batch_size,
                kv_len,
                head_num_q,
                head_num_kv,
                head_dim,
                num_kv_splits=8,
            ),
            quantiles=quantiles,
        )
    if provider == "flashinfer":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: decode_attention_flashinfer(
                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
            ),
            quantiles=quantiles,
        )
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.int8, device="cuda"
    )
    global flashinfer_decode_wrapper
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=False
    )
    calculate_diff()
    benchmark.run(print_data=True)
```
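One detail worth calling out in this script (and kept in its replacement) is the `attn_logits` buffer of shape `(batch_size, head_num_q, num_kv_splits, head_dim + 1)`: the Triton decode kernel processes the KV cache in `num_kv_splits` chunks, and the extra slot per split holds a log-sum-exp term used to merge the partial results. Below is a minimal sketch of that merge, assuming each split stores its partial output already normalized by its own log-sum-exp; it is illustrative only, not the actual kernel code.

```python
import torch

def merge_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    """Combine per-split partial attention outputs (illustrative sketch).

    attn_logits: (batch, heads, num_kv_splits, head_dim + 1), where
    [..., :head_dim] is a split's partial output and [..., head_dim]
    is that split's log-sum-exp over its KV chunk.
    """
    partial_out = attn_logits[..., :-1]   # (B, H, S, D)
    lse = attn_logits[..., -1]            # (B, H, S)
    weights = torch.softmax(lse, dim=-1)  # e^{lse_s} / sum_s e^{lse_s}
    return (weights.unsqueeze(-1) * partial_out).sum(dim=2)  # (B, H, D)
```

Renormalizing with a softmax over the per-split log-sum-exps is exactly what makes the chunked computation equivalent to a single softmax over the full KV length.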
New file: benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py (0 → 100644, view at e2102669)
```python
import itertools
import math

import cudnn
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)
from sglang.srt.utils import should_use_tensor_core


def benchmark_forward(
    fn,
    *inputs,
    repeats=10,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    return t, m


def time_fwd(func, *args, **kwargs):
    time_f = benchmark_forward(func, *args, **kwargs)
    return time_f[1].mean * 1e6


def decode_attention_sglang(
    q,
    kv_data,
    batch_size,
    kv_len,
    head_num_q,
    head_num_kv,
    head_dim,
    num_kv_splits,
    warmup=10,
):
    k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
    v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
    o = torch.empty_like(q)
    total_tokens = batch_size * kv_len
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
    max_len_in_batch = kv_len
    sm_scale = 1.0 / (head_dim**0.5)

    attn_logits = torch.empty(
        (batch_size, head_num_q, num_kv_splits, head_dim + 1),
        dtype=torch.float32,
        device="cuda",
    )

    for _ in range(warmup):
        decode_attention_fwd(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_req_idx,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
        )

    f = time_fwd(
        decode_attention_fwd,
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )
    return f, o


def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.int8, device="cuda"
    )
    use_tensor_cores = should_use_tensor_core(
        kv_cache_dtype=dtype,
        num_attention_heads=head_num_q,
        num_kv_heads=head_num_kv,
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )

    class FlashinferAttention(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            q,
            kv_data,
            batch_size,
            kv_len,
            head_num_q,
            head_num_kv,
            head_dim,
            dtype,
            warmup=10,
        ):
            total_tokens = batch_size * kv_len
            kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
            kv_indices = torch.arange(0, total_tokens).to(0).int()
            kv_last_page_len = torch.full(
                (batch_size,), 1, dtype=torch.int32, device="cuda"
            )

            flashinfer_decode_wrapper.end_forward()
            flashinfer_decode_wrapper.begin_forward(
                kv_indptr,
                kv_indices,
                kv_last_page_len,
                head_num_q,
                head_num_kv,
                head_dim,
                1,
                pos_encoding_mode="NONE",
                data_type=dtype,
            )

            for _ in range(warmup):
                o = flashinfer_decode_wrapper.forward(
                    q.contiguous().view(-1, head_num_q, head_dim), kv_data
                )

            f = time_fwd(
                flashinfer_decode_wrapper.forward,
                q.contiguous().view(-1, head_num_q, head_dim),
                kv_data,
            )
            return f, o

    return FlashinferAttention


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def decode_attention_cudnn(
    q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
    # Prepare data: continuous q,k,v
    dims_q = (batch_size, head_num_q, 1, head_dim)
    strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
    q_gpu = q.as_strided(dims_q, strides_q)
    o_gpu = (
        torch.empty(batch_size * head_num_q * head_dim)
        .half()
        .cuda()
        .as_strided(dims_q, strides_q)
    )
    dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
    strides_kv = (
        kv_len * head_num_kv * head_dim,
        head_dim,
        head_num_kv * head_dim,
        1,
    )
    k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
    v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
    seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
    seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
    attn_scale = 1.0 / (head_dim**0.5)

    # Prepare data: paged k,v
    block_size = 1
    blocks_per_batch = math.ceil(kv_len / block_size)
    # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
    container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
    page_table_k_gpu = (
        torch.linspace(
            0,
            batch_size * blocks_per_batch - 1,
            batch_size * blocks_per_batch,
            device="cuda",
            dtype=torch.int32,
        )
        .reshape(blocks_per_batch, 1, batch_size, 1)
        .transpose(0, 2)
    )
    page_table_v_gpu = page_table_k_gpu.clone()

    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu)
    container_k = graph.tensor_like(container_k_gpu)
    container_v = graph.tensor_like(container_v_gpu)
    page_table_k = graph.tensor_like(page_table_k_gpu)
    page_table_v = graph.tensor_like(page_table_v_gpu)
    seq_len_q = graph.tensor_like(seq_len_q_gpu)
    seq_len_kv = graph.tensor_like(seq_len_kv_gpu)

    o, _ = graph.sdpa(
        name="sdpa",
        q=q,
        k=container_k,  # Container K: non contiguous container with K blocks
        v=container_v,  # Container V: non contiguous container with V blocks
        is_inference=True,
        attn_scale=attn_scale,
        use_causal_mask=False,
        use_padding_mask=True,
        seq_len_q=seq_len_q,
        seq_len_kv=seq_len_kv,
        paged_attention_k_table=page_table_k,  # Page Table K: Tensor containing offsets to the container with K blocks
        paged_attention_v_table=page_table_v,  # Page Table V: Tensor containing offsets to the container with V blocks
        paged_attention_max_seq_len_kv=kv_len,  # The maximum sequence length for K caches (this is optional, but recommended)
    )

    o.set_output(True).set_dim(dims_q).set_stride(strides_q)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A])
    graph.check_support()
    graph.build_plans()

    workspace = torch.empty(
        graph.get_workspace_size(), device="cuda", dtype=torch.uint8
    )

    variant_pack = {
        q: q_gpu,
        container_k: container_k_gpu,
        container_v: container_v_gpu,
        page_table_k: page_table_k_gpu,
        page_table_v: page_table_v_gpu,
        seq_len_q: seq_len_q_gpu,
        seq_len_kv: seq_len_kv_gpu,
        o: o_gpu,
    }

    for _ in range(warmup):
        graph.execute(variant_pack, workspace)

    f = time_fwd(
        graph.execute,
        variant_pack,
        workspace,
    )
    return f, o_gpu.squeeze(dim=2)


def calculate_diff():
    dtype = torch.float16
    batch_size = 64
    kv_len = 4096
    head_num_q = 64
    head_num_kv = 8
    head_dim = 128

    q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
    kv_data = (
        torch.randn(
            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
        ),
        torch.randn(
            batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
        ),
    )

    _, output_sglang = decode_attention_sglang(
        q,
        kv_data,
        batch_size,
        kv_len,
        head_num_q,
        head_num_kv,
        head_dim,
        num_kv_splits=8,
    )

    attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
    _, output_flashinfer = attn_flashinfer(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    _, output_cudnn = decode_attention_cudnn(
        q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
    )

    print(f"SGLang output={output_sglang}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"cuDNN output={output_cudnn}")
    if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
        print("✅ SGLang[Triton] and FlashInfer match")
    else:
        print("❌ SGLang[Triton] and FlashInfer differ")
    if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
        print("✅ SGLang[Triton] and cuDNN match")
    else:
        print("❌ SGLang[Triton] and cuDNN differ")


if __name__ == "__main__":
    calculate_diff()

    head_dim = 128
    dtype = torch.float16
    batch_size_range = [2**i for i in range(0, 8, 2)]
    kv_len_range = [2**i for i in range(6, 13, 1)]
    configs = list(itertools.product(batch_size_range, kv_len_range))

    for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
        attn_flashinfer = decode_attention_flashinfer(
            dtype, head_num_q, head_num_kv
        ).apply
        for batch_size, kv_len in configs:
            q = torch.randn(
                batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
            )
            kv_data = (
                torch.randn(
                    batch_size * kv_len,
                    head_num_kv,
                    head_dim,
                    dtype=dtype,
                    device="cuda",
                ),
                torch.randn(
                    batch_size * kv_len,
                    head_num_kv,
                    head_dim,
                    dtype=dtype,
                    device="cuda",
                ),
            )
            us_cudnn, output_cudnn = decode_attention_cudnn(
                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
            )
            us_sglang, output_sglang = decode_attention_sglang(
                q,
                kv_data,
                batch_size,
                kv_len,
                head_num_q,
                head_num_kv,
                head_dim,
                num_kv_splits=8,
            )
            us_flashinfer, _ = attn_flashinfer(
                q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
            )
            print(
                head_num_q,
                " ",
                head_num_kv,
                " ",
                batch_size,
                " ",
                kv_len,
                " ",
                us_cudnn,
                " ",
                us_sglang,
                " ",
                us_flashinfer,
            )
```
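Both the FlashInfer wrapper and the cuDNN graph above consume the KV cache through an index layer: `kv_indptr`/`kv_indices` for FlashInfer and `page_table_k`/`page_table_v` for cuDNN, all built here with page size 1 so each "page" holds a single token. A minimal sketch of the logical-to-physical lookup this implies follows; `gather_kv_for_request` is a hypothetical helper for illustration, not part of either API.

```python
import torch

def gather_kv_for_request(
    kv_pool: torch.Tensor,     # (total_tokens, head_num_kv, head_dim) flat cache
    kv_indptr: torch.Tensor,   # (batch_size + 1,) prefix offsets into kv_indices
    kv_indices: torch.Tensor,  # (total_tokens,) physical row ids, page size 1
    req: int,
) -> torch.Tensor:
    # Request `req` owns kv_indices[kv_indptr[req] : kv_indptr[req + 1]];
    # each entry names one physical row (one token) in the flat pool.
    start, end = int(kv_indptr[req]), int(kv_indptr[req + 1])
    return kv_pool[kv_indices[start:end]]  # (kv_len, head_num_kv, head_dim)
```

With the benchmark's `kv_indptr = torch.arange(0, batch_size + 1) * kv_len` and `kv_indices = torch.arange(0, total_tokens)`, request `i` simply owns the contiguous rows `[i * kv_len, (i + 1) * kv_len)`; the indirection only pays off when pages are allocated non-contiguously, as in a real serving cache.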