Unverified Commit 4fc5f2f9 authored by Ke Bao, committed by GitHub

Add unit test for triton swa kernel (#8853)

parent 168033d5
@@ -2,6 +2,7 @@ import random
import unittest
import torch
import torch.nn.functional as F
from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
from sglang.test.test_utils import CustomTestCase
def extend_attention_fwd_torch(
    q: torch.Tensor,  # [extend_tokens, H_Q, D]
    k: torch.Tensor,  # [extend_tokens, H_KV, D]
    v: torch.Tensor,  # [extend_tokens, H_KV, D]
    o: torch.Tensor,  # [extend_tokens, H_Q, D]
    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
    qo_indptr: torch.Tensor,  # [B+1]
    kv_indptr: torch.Tensor,  # [B+1]
    kv_indices: torch.Tensor,  # [prefix_tokens]
    sliding_window_size: int,
):
    B = qo_indptr.size(0) - 1
    _, H_Q, D = q.shape
    _, H_KV, _ = k.shape
    group_size = H_Q // H_KV
    scale = 1.0 / D**0.5

    for i in range(B):
        q_start = int(qo_indptr[i].item())
        q_end = int(qo_indptr[i + 1].item())
        kv_start = int(kv_indptr[i].item())
        kv_end = int(kv_indptr[i + 1].item())

        prefix_indices = kv_indices[kv_start:kv_end]
        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]
        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]

        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]
        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]
        q_extend = q[q_start:q_end]  # [extend_len, H_Q, D]

        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]
        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]

        if group_size != 1:
            k_full_hq = k_full.repeat_interleave(
                group_size, dim=1
            )  # [total_len, H_Q, D]
            v_full_hq = v_full.repeat_interleave(
                group_size, dim=1
            )  # [total_len, H_Q, D]
        else:
            k_full_hq = k_full
            v_full_hq = v_full

        prefix_len = k_prefix.size(0)
        extend_len = k_extend.size(0)
        total_len = prefix_len + extend_len

        # causal
        pos_keys = torch.arange(total_len, device=q.device)
        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]
        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)

        # sliding window
        if sliding_window_size is not None and sliding_window_size > 0:
            start = (t - sliding_window_size).clamp_min(0)  # [extend_len]
        else:
            start = torch.zeros_like(t)
        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)

        final_mask = causal_mask & window_mask

        attn_scores = (
            torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
        )  # [extend_len, H_Q, total_len]
        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
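# A minimal worked example of the mask built above, assuming hypothetical values
# prefix_len=2, extend_len=3, sliding_window_size=2: an extend query at absolute
# position t may attend to key positions p with max(0, t - 2) <= p <= t.
#
#     pos_keys = torch.arange(5)
#     t = 2 + torch.arange(3)
#     mask = (pos_keys[None, :] <= t[:, None]) & (
#         pos_keys[None, :] >= (t - 2).clamp_min(0)[:, None]
#     )
#     # mask.int() ->
#     # [[1, 1, 1, 0, 0],
#     #  [0, 1, 1, 1, 0],
#     #  [0, 0, 1, 1, 1]]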
class TestTritonAttention(CustomTestCase):
    def _set_all_seeds(self, seed):
@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase):
        for value in attention_values:
            self._test_extend_attention_once(19, 12331, 12, 4, value)
    def _test_extend_attention_sliding_window_once(
        self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
    ):
        dtype = torch.bfloat16

        b_seq_len_prefix = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len_extend = torch.randint(
            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
        )
        b_seq_len = b_seq_len_prefix + b_seq_len_extend

        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
        kv_indices = torch.zeros(
            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
        )
        for i in range(B):
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()
        k_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)
        v_buffer = torch.empty(
            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
        ).normal_(mean=0.1, std=0.2)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.empty(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
            ).normal_(mean=0.1, std=0.2)

        o_extend_triton = torch.empty(
            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
        )
        o_extend_torch = torch.empty(
            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
        )

        b_seq_len_extend = b_seq_len - b_seq_len_prefix
        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_extend_triton,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask=None,
            is_causal=True,
            mask_indptr=None,
            max_len_extend=max_len_extend,
            sliding_window_size=WINDOW_SIZE,
        )

        extend_attention_fwd_torch(
            q_extend,
            k_extend,
            v_extend,
            o_extend_torch,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            WINDOW_SIZE,
        )

        self.assertTrue(
            torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
        )

    def test_extend_attention_sliding_window(self):
        window_sizes = [-1, 127]
        for window_size in window_sizes:
            self._test_extend_attention_sliding_window_once(
                19, 12331, 64, 8, 128, window_size
            )
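    # Sketch of the ragged index layout the test above constructs, for a
    # hypothetical batch of B=2 with prefix lengths [2, 1] and extend lengths
    # [1, 2] (so per-request totals [3, 3] and b_start_loc = [0, 3]):
    #
    #     kv_indptr  = [0, 2, 3]   # per-request offsets into kv_indices
    #     kv_indices = [0, 1, 3]   # rows of k_buffer/v_buffer holding prefix tokens
    #     qo_indptr  = [0, 1, 3]   # per-request offsets into the *_extend tensors
    #
    # Request i reads its cached prefix from
    # k_buffer[kv_indices[kv_indptr[i]:kv_indptr[i + 1]]] and its new tokens from
    # rows qo_indptr[i]:qo_indptr[i + 1] of q_extend/k_extend/v_extend. A window
    # size of -1 exercises the plain causal path, since the torch reference treats
    # any non-positive sliding_window_size as "no window".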
    def _test_context_attention_once(self, head_dim, is_causal):
        # Set up a simple test case
        num_heads = 4
...