Unverified Commit 0475448e authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Optimize triton swa kernel by skipping computation (#8860)

parent 399e7ec8
import itertools
import torch
import torch.nn.functional as F
import triton.testing as tt
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
def extend_attention_fwd_torch(
q: torch.Tensor, # [extend_tokens, H_Q, D]
k: torch.Tensor, # [extend_tokens, H_KV, D]
v: torch.Tensor, # [extend_tokens, H_KV, D]
o: torch.Tensor, # [extend_tokens, H_Q, D]
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
qo_indptr: torch.Tensor, # [B+1]
kv_indptr: torch.Tensor, # [B+1]
kv_indices: torch.Tensor, # [prefix_tokens]
sliding_window_size: int,
):
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = 1.0 / D**0.5
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
if group_size != 1:
k_full_hq = k_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
v_full_hq = v_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
# causal
pos_keys = torch.arange(total_len, device=q.device)
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
# sliding window
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = (
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
) # [extend_len, H_Q, total_len]
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
def _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
):
b_seq_len_prefix = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len_extend = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
)
for i in range(B):
s = kv_indptr[i].item()
e = kv_indptr[i + 1].item()
kv_indices[s:e] = torch.arange(
b_start_loc[i],
b_start_loc[i] + b_seq_len_prefix[i],
dtype=torch.int32,
device=device,
)
total_token_num = int(torch.sum(b_seq_len).item())
extend_token_num = int(torch.sum(b_seq_len_extend).item())
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
o_extend_triton = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device=device
)
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
inputs = dict(
q_extend=q_extend,
k_extend=k_extend,
v_extend=v_extend,
k_buffer=k_buffer,
v_buffer=v_buffer,
o_extend_triton=o_extend_triton,
o_extend_torch=o_extend_torch,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
max_len_extend=max_len_extend,
WINDOW_SIZE=WINDOW_SIZE,
)
meta = dict(
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
)
return inputs, meta
def _run_triton(inputs):
extend_attention_fwd(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_triton"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=inputs["max_len_extend"],
sliding_window_size=inputs["WINDOW_SIZE"],
)
def _run_torch_ref(inputs):
extend_attention_fwd_torch(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_torch"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
inputs["WINDOW_SIZE"],
)
N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
PROVIDERS = ["torch", "triton"]
@tt.perf_report(
tt.Benchmark(
x_names=["N_CTX", "WINDOW_SIZE"],
x_vals=CONFIGS,
line_arg="provider",
line_vals=PROVIDERS,
line_names=PROVIDERS,
ylabel="Runtime (ms)",
plot_name="extend_attention_triton_vs_torch",
args={
"B": 32,
"H_Q": 64,
"H_KV": 8,
"D": 128,
"dtype": "bf16",
"device": "cuda",
"check_correctness": False,
"warmup": 25,
"rep": 100,
},
)
)
def bench(
N_CTX,
provider,
B,
H_Q,
H_KV,
D,
dtype,
device,
WINDOW_SIZE,
check_correctness,
warmup,
rep,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
dt = dtype_map[dtype]
inputs, _ = _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
)
if check_correctness and provider == "triton":
_run_triton(inputs)
_run_torch_ref(inputs)
torch.cuda.synchronize()
if not torch.allclose(
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
):
raise AssertionError("Mismatch between triton and torch reference.")
if provider == "triton":
ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
elif provider == "torch":
ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
else:
raise ValueError(provider)
return ms
if __name__ == "__main__":
bench.run(print_data=True, show_plots=False)
......@@ -134,38 +134,6 @@ def _fwd_kernel(
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
)
# load k in transposed way
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
custom_mask = tl.load(
......@@ -185,28 +153,72 @@ def _fwd_kernel(
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
SKIP_TILE = False
if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
if not SKIP_TILE:
offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
mask=mask_n,
other=0,
)
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
# load k in transposed way
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(mask_n[None, :]) & (mask_d[:, None]),
other=0.0,
)
e_max = n_e_max
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=mask_n[:, None] & mask_dv[None, :],
other=0.0,
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
# stage 2: compute the triangle part
......@@ -219,35 +231,6 @@ def _fwd_kernel(
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
# load k in transposed way
offs_k = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q, k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
offs_kpe = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK:
custom_mask = tl.load(
......@@ -279,28 +262,62 @@ def _fwd_kernel(
)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
SKIP_TILE = False
if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
if not SKIP_TILE:
# load k in transposed way
offs_k = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
offs_v = (
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
qk = tl.dot(q, k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
offs_kpe = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_v = (
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
e_max = n_e_max
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment