Unverified commit 9c7e3924 authored by eigen, committed by GitHub

bench: add attention sink op benchmark, triton and trtllm-gen [B200] (#8932)


Co-authored-by: averyhuang <averyh@nvidia.com>
parent 08fab2b0
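
"""Benchmark Triton attention kernels with attention sinks (decode and extend paths).

Uses a GPT-OSS-like GQA shape (64 query heads, 8 KV heads, head_dim 64) and
reports TFLOPS via triton.testing.perf_report. Run with
``--bench {all,decode,extend}`` (default: all).
"""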
import argparse

import torch
import triton

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd

# GPT-OSS attention shape: 64 query heads, 8 KV heads, head_dim 64
head_num = 64
head_dim = 64
head_kv_num = 8


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on the x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-decode",
        args={},
    )
)
def benchmark_decode(B, S, H_Q, H_KV, D):
    D_V = D
    dtype = torch.bfloat16
    seq_len = S
    total_tokens = B * seq_len
    device = torch.device("cuda")
    sm_scale = 1.0 / (D**0.5)
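    # Split-KV (flash-decoding style) setup: each request's KV sequence is
    # processed in up to max_kv_splits partitions, whose partial outputs and
    # log-sum-exp values are then combined into the final result.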
    max_kv_splits = 8
    num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

    # q represents the new token being generated, one per request
    q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
    # k_buffer and v_buffer hold all previous tokens
    k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B,), seq_len, device="cuda")
    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
    kv_indices = torch.arange(total_tokens, device="cuda")
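    # CSR-style layout: kv_indptr[i]:kv_indptr[i + 1] delimits request i's
    # slice of kv_indices. e.g. B=2, seq_len=3 -> kv_indptr=[0, 3, 6],
    # kv_indices=[0, 1, 2, 3, 4, 5].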
    attn_logits1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )
    attn_lse1 = torch.empty(
        (B, H_Q, max_kv_splits, D_V),
        dtype=torch.float32,
        device="cuda",
    )
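    # Attention-sink logits, one learned scalar per query head (as in GPT-OSS):
    # each sink joins the softmax normalization without contributing a value.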
    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    # warmup
    for _ in range(5):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            kv_indptr,
            kv_indices,
            attn_logits1,
            attn_lse1,
            num_kv_splits,
            max_kv_splits,
            sm_scale,
            logit_cap=0.0,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step
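    # FLOP model counts the QK^T and PV multiplies: 2 * B * S * H_Q * D.
    # e.g. B=32, S=1024, H_Q=64, D=64 -> ~0.27 GFLOP per decode step.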
    tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms  # FLOPs * 1e-9 / ms == TFLOPS
    return tflops(ms)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["S"],  # sequence length on the x-axis
        x_vals=[128, 256, 512, 1024, 2048, 4096],
        x_log=True,
        line_arg="B",  # batch size as different lines
        line_vals=[1, 8, 32, 128],
        line_names=["B=1", "B=8", "B=32", "B=128"],
        styles=[
            ("blue", "-"),
            ("green", "-"),
            ("red", "-"),
            ("cyan", "-"),
        ],
        ylabel="TFLOPS",
        plot_name="attention-sink-triton-extend",
        args={},
    )
)
def benchmark_extend(B, S, H_Q, H_KV, D):
    # S is the total context length; split it into a cached prefix and an
    # extend (new-token) chunk.
    dtype = torch.bfloat16
    device = "cuda"
    prefill_len = S // 2  # cached prefix length
    extend_len = S // 4  # extend length, kept smaller than the prefix

    # Total token counts across the batch
    total_extend_tokens = B * extend_len
    total_prefix_tokens = B * prefill_len

    # Query, key, value tensors for the extend chunk
    q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
    k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
    o_extend = torch.empty_like(q_extend)

    # Key-value buffers holding the cached prefix
    k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
    v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)

    # Index pointers
    qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(
        torch.int32
    )
    kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(
        torch.int32
    )
    kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)
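    # e.g. B=2, S=512 -> prefill_len=256, extend_len=128:
    # qo_indptr=[0, 128, 256] (extend queries per request),
    # kv_indptr=[0, 256, 512] (cached prefix keys/values per request).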
    sm_scale = 1.0 / (D**0.5)
    # sliding_window = 128  # from the GPT-OSS config; skipped for now
    sliding_window = -1
    sink = torch.randn(H_Q, device=device, dtype=torch.float32)

    # warmup
    for _ in range(5):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )

    # benchmark
    run_step = 500
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(run_step):
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=extend_len,
            sm_scale=sm_scale,
            sliding_window_size=sliding_window,
            sinks=sink,
        )
    end_event.record()
    end_event.synchronize()
    torch.cuda.synchronize()
    ms = start_event.elapsed_time(end_event) / run_step
    # FLOP model counts the QK^T and PV multiplies; each extend query attends
    # to the full cached prefix plus, causally, half the extend chunk on average.
    total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
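    # e.g. B=32, S=1024 -> total_extend_tokens=8192, prefill_len=512,
    # extend_len=256: total_flops = 2*8192*64*(512+128)*64 ~= 42.9 GFLOP.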
    tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3)  # FLOPs/s -> TFLOPS
    return tflops(ms)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--bench", type=str, default="all", help="all, extend, decode")
args = parser.parse_args()
kwargs = {
"H_Q": head_num,
"H_KV": head_kv_num,
"D": head_dim,
}
if args.bench in ["all", "decode"]:
benchmark_decode.run(print_data=True, show_plots=False, **kwargs)
if args.bench in ["all", "extend"]:
benchmark_extend.run(print_data=True, show_plots=False, **kwargs)
print("Benchmark finished!")