Commit 9c7e3924 (unverified) to sglang, authored Aug 08, 2025 by eigen and committed by GitHub on Aug 08, 2025. Parent: 08fab2b0.

bench: add attention sink op benchmark, triton and trtllm-gen [B200] (#8932)

Co-authored-by: averyhuang <averyh@nvidia.com>

1 changed file with 250 additions and 0 deletions:
benchmark/bench_attention_sink/bench_attention_sink_triton.py (new file, mode 100644, +250 −0)
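The new script, reproduced below, benchmarks the attention-sink variants of the grouped Triton decode kernel (`decode_attention_fwd_grouped`) and the extend/prefill kernel (`extend_attention_fwd`) for a GPT-OSS-like head layout (64 query heads, 8 KV heads, head dim 64). It sweeps sequence lengths 128–4096 and batch sizes 1, 8, 32, and 128, and reports estimated TFLOPS via `triton.testing.perf_report`. A single `--bench` flag selects `decode`, `extend`, or the default `all`.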
```python
import argparse

import torch
import triton

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

# gpt oss
head_num = 64
head_dim = 64
head_kv_num = 8


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("cyan", "-")],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-decode",
        args={},
    )
)
def benchmark_decode(B, S, H_Q, H_KV, D):
    D_V = D
    dtype = torch.bfloat16
    seq_len = S
    total_tokens = B * seq_len
    device = torch.device("cuda")
    sm_scale = 1.0 / (D**0.5)
    max_kv_splits = 8
    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")

    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
    b_seq_len = torch.full((B,), seq_len, device="cuda")

    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
    kv_indices = torch.arange(total_tokens, device="cuda")

    attn_logits1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )
    attn_lse1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )

    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    # warmup
    for _ in range(5):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step

    tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms  # must be causal
    return tflops(ms)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("cyan", "-")],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-extend",
        args={},
    )
)
def benchmark_extend(B, S, H_Q, H_KV, D):
    # S here represents N_CTX from the test
    dtype = torch.bfloat16
    device = "cuda"

    # Split S into prefix and extend lengths
    prefill_len = S // 2  # Similar to test's N_CTX // 2
    extend_len = S // 4  # Make extend length smaller than prefix

    # Calculate total tokens and extend tokens
    total_extend_tokens = B * extend_len
    total_prefix_tokens = B * prefill_len

    # Create query, key, value tensors for extension
    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    o_extend = torch.empty_like(q_extend)

    # Create key-value buffers for prefix
    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)

    # Create index pointers
    qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(
        torch.int32
    )
    kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(
        torch.int32
    )
    kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)

    sm_scale = 1.0 / (D**0.5)
    # sliding_window = 128 # From GPT-OSS config, skip for now
    sliding_window = -1

    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    # warmup
    for _ in range(5):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step

    # FLOPS calculation: each attention operation requires 2 multiplications per element
    total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
    tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3)  # convert to TFLOPS
    return tflops(ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bench", type=str, default="all", help="all, extend, decode")
    args = parser.parse_args()

    kwargs = {
        "H_Q": head_num,
        "H_KV": head_kv_num,
        "D": head_dim,
    }

    if args.bench in ["all", "decode"]:
        benchmark_decode.run(print_data=True, show_plots=False, **kwargs)
    if args.bench in ["all", "extend"]:
        benchmark_extend.run(print_data=True, show_plots=False, **kwargs)

    print("Benchmark finished!")
```
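For reference, below is a minimal standalone sketch of the FLOP estimates that the two benchmarks divide by the measured kernel time. The helper names (`decode_flops`, `extend_flops`) are illustrative and not part of the commit; the constants mirror the `# gpt oss` values at the top of the script, and the example timing is made up purely to show the unit conversion.

```python
# Standalone sketch (not part of the commit) of the FLOP estimates used above.
# Helper names are illustrative; constants mirror the "gpt oss" head config.
H_Q, D = 64, 64


def decode_flops(B: int, S: int) -> int:
    # benchmark_decode: one new query token per sequence attends to S cached
    # KV tokens, counted as 2 * B * S * H_Q * D FLOPs per decode step.
    return 2 * B * S * H_Q * D


def extend_flops(B: int, S: int) -> float:
    # benchmark_extend: S is split into a prefix of S // 2 and an extend chunk
    # of S // 4; each extend token attends to the whole prefix plus, on
    # average, half of the causal extend window.
    prefill_len, extend_len = S // 2, S // 4
    total_extend_tokens = B * extend_len
    return 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D


if __name__ == "__main__":
    ms = 0.5  # example kernel time in milliseconds
    # The two conversions used in the script are equivalent when ms is in
    # milliseconds: flops * 1e-9 / ms == flops * 1e-12 / (ms * 1e-3).
    print(f"decode TFLOPS ~ {decode_flops(32, 2048) * 1e-9 / ms:.2f}")
    print(f"extend TFLOPS ~ {extend_flops(32, 2048) * 1e-12 / (ms * 1e-3):.2f}")
```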