Unverified Commit 4fc5f2f9 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Add unit test for triton swa kernel (#8853)

parent 168033d5
...@@ -2,6 +2,7 @@ import random ...@@ -2,6 +2,7 @@ import random
import unittest import unittest
import torch import torch
import torch.nn.functional as F
from sglang.srt.layers.attention.triton_ops.decode_attention import ( from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd, decode_attention_fwd,
...@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import ( ...@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
def extend_attention_fwd_torch(
q: torch.Tensor, # [extend_tokens, H_Q, D]
k: torch.Tensor, # [extend_tokens, H_KV, D]
v: torch.Tensor, # [extend_tokens, H_KV, D]
o: torch.Tensor, # [extend_tokens, H_Q, D]
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
qo_indptr: torch.Tensor, # [B+1]
kv_indptr: torch.Tensor, # [B+1]
kv_indices: torch.Tensor, # [prefix_tokens]
sliding_window_size: int,
):
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = 1.0 / D**0.5
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
if group_size != 1:
k_full_hq = k_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
v_full_hq = v_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
# causal
pos_keys = torch.arange(total_len, device=q.device)
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
# sliding window
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = (
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
) # [extend_len, H_Q, total_len]
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
class TestTritonAttention(CustomTestCase): class TestTritonAttention(CustomTestCase):
def _set_all_seeds(self, seed): def _set_all_seeds(self, seed):
...@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase): ...@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase):
for value in attention_values: for value in attention_values:
self._test_extend_attention_once(19, 12331, 12, 4, value) self._test_extend_attention_once(19, 12331, 12, 4, value)
def _test_extend_attention_sliding_window_once(
self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
):
dtype = torch.bfloat16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
)
for i in range(B):
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
o_extend_triton = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
o_extend_torch = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend_triton,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=max_len_extend,
sliding_window_size=WINDOW_SIZE,
)
extend_attention_fwd_torch(
q_extend,
k_extend,
v_extend,
o_extend_torch,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
WINDOW_SIZE,
)
self.assertTrue(
torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
)
def test_extend_attention_sliding_window(self):
window_sizes = [-1, 127]
for window_size in window_sizes:
self._test_extend_attention_sliding_window_once(
19, 12331, 64, 8, 128, window_size
)
def _test_context_attention_once(self, head_dim, is_causal): def _test_context_attention_once(self, head_dim, is_causal):
# Set up a simple test case # Set up a simple test case
num_heads = 4 num_heads = 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment