# SPDX-License-Identifier: MIT
import torch
import torch.nn.functional as F
import pytest
from aiter.ops.triton.extend_attention import extend_attention_fwd


def extend_attention_fwd_torch_swa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    qo_indptr: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_indices: torch.Tensor,
    sliding_window_size: int,
    *,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    sm_scale: float | None = None,
):
    """Reference for causal + sliding-window extend attention (sglang test style).

    Runs the heavy matmul/softmax on CPU float32 for numerical stability and to
    avoid ROCm aborts on large bf16 einsum after GPU kernels.

    Matches the v2 Triton kernel: ``k_scale`` / ``v_scale`` are applied **only to
    the prefix (cache) key positions**; the extend-segment logits and V are not
    multiplied by these scalars.

    For sequence ``i`` the prefix K/V rows are gathered from ``k_cache`` /
    ``v_cache`` via ``kv_indices[kv_indptr[i]:kv_indptr[i + 1]]`` and the extend
    rows are ``k`` / ``v[qo_indptr[i]:qo_indptr[i + 1]]``. A query at absolute
    position ``t`` (``prefix_len + j``) attends to key positions ``p`` with
    ``p <= t`` and, when ``sliding_window_size > 0``, also
    ``p >= t - sliding_window_size``.

    The result is written in place into ``o[qo_indptr[i]:qo_indptr[i + 1]]``
    using ``o``'s device and dtype; nothing is returned. ``sm_scale`` defaults to
    ``1 / sqrt(head_dim)``.
    """
    B = qo_indptr.size(0) - 1
    _, H_Q, D = q.shape
    _, H_KV, _ = k.shape
    group_size = H_Q // H_KV
    scale = float(sm_scale) if sm_scale is not None else 1.0 / D**0.5
    out_dev = o.device
    out_dtype = o.dtype
    cpu = torch.device("cpu")
    use_window = sliding_window_size is not None and sliding_window_size > 0

    for i in range(B):
        q_start = int(qo_indptr[i].item())
        q_end = int(qo_indptr[i + 1].item())
        kv_start = int(kv_indptr[i].item())
        kv_end = int(kv_indptr[i + 1].item())

        # Gather this sequence's prefix from the cache and append the extend
        # tokens after it, so key position == absolute sequence position.
        prefix_indices = kv_indices[kv_start:kv_end]
        k_prefix = k_cache[prefix_indices]
        v_prefix = v_cache[prefix_indices]
        k_extend = k[q_start:q_end]
        v_extend = v[q_start:q_end]
        q_extend = q[q_start:q_end]

        k_full = torch.cat([k_prefix, k_extend], dim=0)
        v_full = torch.cat([v_prefix, v_extend], dim=0)

        # Expand each KV head to the ``group_size`` query heads that share it.
        if group_size != 1:
            k_full = k_full.repeat_interleave(group_size, dim=1)
            v_full = v_full.repeat_interleave(group_size, dim=1)

        prefix_len = k_prefix.size(0)
        extend_len = k_extend.size(0)
        total_len = prefix_len + extend_len

        q_e = q_extend.detach().float().cpu()
        k_h = k_full.detach().float().cpu()
        v_h = v_full.detach().float().cpu()

        # Build positions/masks explicitly on CPU so they always live on the
        # same device as the CPU scores, even if another test has changed the
        # process-wide default device (masked_fill rejects a mask on another
        # device).
        pos_keys = torch.arange(total_len, device=cpu)
        t = prefix_len + torch.arange(extend_len, device=cpu)
        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
        if use_window:
            start = (t - sliding_window_size).clamp_min(0)
        else:
            start = torch.zeros_like(t)
        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
        final_mask = causal_mask & window_mask

        attn_scores = torch.einsum("qhd,khd->qhk", q_e, k_h) * scale
        if k_scale != 1.0:
            # k_scale only applies to the prefix (cache) key columns.
            attn_scores[:, :, :prefix_len] = attn_scores[:, :, :prefix_len] * k_scale
        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)

        if v_scale != 1.0:
            # v_scale only applies to the prefix (cache) value rows.
            v_h = torch.cat([v_h[:prefix_len] * v_scale, v_h[prefix_len:]], dim=0)

        out_cpu = torch.einsum("qhk,khd->qhd", attn_weights, v_h)
        o[q_start:q_end] = out_cpu.to(device=out_dev, dtype=out_dtype)


def input_helper(
    B,
    H,
    prefix_length,
    extend_length,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    dtype,
    device,
    attn_impl="normal",
    equal_seqlens=False,
    requires_grad=False,
    kv_num_heads: int = 1,
):
    """Build random extend-attention inputs with a fixed seed (``manual_seed(0)``).

    With ``equal_seqlens=False`` the extend lengths are uniform in
    ``[1, extend_length]``. The prefix lengths are uniform in
    ``[1, prefix_length]``, or all zero when ``prefix_length == 0``. With
    ``equal_seqlens=True`` every sequence uses exactly those lengths.

    ``attn_impl="absorb"`` uses ``Lq = Lk = kv_lora_rank + qk_rope_head_dim`` and
    ``Lv = kv_lora_rank``, with V a view of the first ``Lv`` channels of K. Any
    other value uses ``Lq = Lk = v_head_dim + qk_rope_head_dim``,
    ``Lv = v_head_dim`` and independently sampled V.

    Returns ``(q_extend, k_extend, v_extend, k_buffer, v_buffer, kv_indptr,
    kv_indices, qo_indptr, custom_mask, mask_indptr, max_len_extend)``.
    ``custom_mask`` / ``mask_indptr`` are always ``None``. ``max_len_extend`` is
    the requested ``extend_length``, an upper bound on the sampled lengths.
    """
    torch.manual_seed(0)
    if equal_seqlens:
        seqlens_extend = torch.full((B,), extend_length, dtype=torch.int32)
        seqlens_prefix = torch.full((B,), prefix_length, dtype=torch.int32)
    else:
        seqlens_extend = torch.randint(
            1, extend_length + 1, (B,), dtype=torch.int32
        )
        if prefix_length == 0:
            seqlens_prefix = torch.full((B,), prefix_length, dtype=torch.int32)
        else:
            seqlens_prefix = torch.randint(
                1, prefix_length + 1, (B,), dtype=torch.int32
            )

    cu_seqlens_extend = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            seqlens_extend.cumsum(dim=0, dtype=torch.int32),
        ]
    )
    cu_seqlens_prefix = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            seqlens_prefix.cumsum(dim=0, dtype=torch.int32),
        ]
    )
    # Honor the caller's ``device`` instead of a hard-coded "cuda".
    cu_seqlens_extend = cu_seqlens_extend.to(device=device)
    cu_seqlens_prefix = cu_seqlens_prefix.to(device=device)

    total_extend = cu_seqlens_extend[-1].item()
    total_prefix = cu_seqlens_prefix[-1].item()

    if attn_impl == "absorb":
        Lq = kv_lora_rank + qk_rope_head_dim
        Lk = kv_lora_rank + qk_rope_head_dim
        Lv = kv_lora_rank
    else:
        Lq = v_head_dim + qk_rope_head_dim
        Lk = v_head_dim + qk_rope_head_dim
        Lv = v_head_dim

    q_extend = torch.randn(
        total_extend, H, Lq, dtype=dtype, device=device
    ).requires_grad_(requires_grad)

    # extend parts
    k_extend = torch.randn(
        total_extend, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_extend = k_extend[..., :Lv]

    # extend indexing
    qo_indptr = cu_seqlens_extend

    # prefix parts
    k_buffer = torch.randn(
        total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_buffer = k_buffer[..., :Lv]

    if attn_impl != "absorb":
        # simulate v = kv_latent * w_vc which changes the values compared to k
        v_extend = torch.randn_like(v_extend, dtype=v_extend.dtype)
        v_buffer = torch.randn_like(v_buffer, dtype=v_buffer.dtype)

    # prefix indexing: prefixes are stored contiguously in k_buffer / v_buffer
    kv_indptr = cu_seqlens_prefix
    kv_indices = torch.arange(total_prefix, device=device, dtype=torch.int32)

    custom_mask = None
    mask_indptr = None
    max_len_extend = extend_length

    return (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
        max_len_extend,
    )


def _v2_flat_causal_custom_mask(B, prefix_len, extend_len, device):
    """Row-major causal mask per batch: [extend_len, prefix_len + extend_len], k_global <= q_global.

    Every batch entry gets the same ``extend_len * (prefix_len + extend_len)``
    bools; ``mask_indptr`` holds the int32 segment offsets (length ``B + 1``).
    """
    total = prefix_len + extend_len
    q_row = torch.arange(extend_len, device=device) + prefix_len
    k_col = torch.arange(total, device=device)
    m = k_col.unsqueeze(0) <= q_row.unsqueeze(1)
    one_batch = m.reshape(-1).contiguous()
    custom_mask = one_batch.repeat(B)
    seg = extend_len * total
    mask_indptr = torch.arange(
        0, (B + 1) * seg, seg, dtype=torch.int32, device=device
    )
    return custom_mask, mask_indptr


def _v2_flat_causal_custom_mask_from_indptr(qo_indptr, kv_indptr, device):
    """Per-sequence causal mask aligned with ``qo_indptr`` / ``kv_indptr`` (variable lengths).

    Batch ``b`` contributes ``extend_b * (prefix_b + extend_b)`` bools, row-major
    ``[extend_b, prefix_b + extend_b]`` with ``k_global <= q_global``. This is the
    same layout as :func:`_v2_flat_causal_custom_mask` for fixed lengths.
    Returns ``(custom_mask, mask_indptr)`` with int32 offsets of length ``B + 1``.
    """
    B = qo_indptr.shape[0] - 1
    segs = []
    mask_indptr = torch.empty(B + 1, dtype=torch.int32, device=device)
    mask_indptr[0] = 0
    pos = 0
    for b in range(B):
        extend_b = int(qo_indptr[b + 1].item() - qo_indptr[b].item())
        prefix_b = int(kv_indptr[b + 1].item() - kv_indptr[b].item())
        total_b = prefix_b + extend_b
        q_row = torch.arange(extend_b, device=device) + prefix_b
        k_col = torch.arange(total_b, device=device)
        m = k_col.unsqueeze(0) <= q_row.unsqueeze(1)
        flat = m.reshape(-1).contiguous()
        segs.append(flat)
        pos += flat.numel()
        mask_indptr[b + 1] = pos
    custom_mask = (
        torch.cat(segs, dim=0)
        if segs
        else torch.empty(0, dtype=torch.bool, device=device)
    )
    return custom_mask, mask_indptr


@pytest.mark.parametrize(
    "B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim, v_head_dim",
    [
        (2, 4, 0, 512, 32, 16, 32),
        (3, 5, 0, 333, 18, 13, 17),
        (3, 5, 512, 333, 18, 0, 17),
        (3, 5, 110, 333, 18, 0, 19),
        # (8, 16, 0, 1024, 128, 0, 128),  # this one passes
        # (8, 16, 0, 16324, 128, 0, 128),  # this one fails, numeric precision is likely the issue
        (2, 1, 64, 32, 128, 64, 128),
        (4, 16, 64, 96, 128, 64, 128),
        (1, 16, 0, 7, 512, 64, 512),
        (1, 16, 7, 4, 512, 64, 512),
        (1, 16, 32, 4, 512, 64, 512),
        (1, 16, 64, 3, 512, 64, 512),
        (1, 16, 127, 4, 512, 64, 512),
        (1, 16, 255, 15, 512, 64, 512),
        (3, 16, 452, 16, 512, 64, 512),
        (4, 16, 512, 14, 512, 64, 512),
        (4, 16, 1024, 16, 512, 64, 512),
        (4, 16, 2048, 13, 512, 64, 512),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("ref_attn_impl", ["normal", "absorb"])
def test_op_fwd(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    dtype,
    ref_attn_impl,
    causal,
    sm_scale=1.0,
    logit_cap=0.0,
    device="cuda",
):
    """Compare ``extend_attention_fwd`` with a dense PyTorch reference.

    The reference scores prefix and extend keys separately in float32. When
    ``causal`` is set, it masks only the extend block, so every query still sees
    the full prefix. It then multiplies by ``sm_scale`` and applies softmax over
    the concatenation. The probabilities are rounded to ``dtype`` before the PV
    product. The reference applies no logit cap, so the comparison assumes the
    default ``logit_cap=0.0`` disables capping in the kernel (TODO confirm).
    """
    torch.manual_seed(0)
    # Scope the default device/dtype to this test. A bare
    # torch.set_default_device / set_default_dtype would leak into every test
    # that runs afterwards in the same process.
    prev_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device(device):
            (
                q_extend,
                k_extend,
                v_extend,
                k_buffer,
                v_buffer,
                kv_indptr,
                kv_indices,
                qo_indptr,
                custom_mask,
                mask_indptr,
                max_len_extend,
            ) = input_helper(
                B,
                H,
                prefix,
                extend,
                kv_lora_rank,
                qk_rope_head_dim,
                v_head_dim,
                dtype,
                device,
                ref_attn_impl,
            )

            tri_out = torch.empty(
                (*q_extend.shape[:-1], v_extend.shape[-1]),
                dtype=q_extend.dtype,
                device=q_extend.device,
            )

            # Kernel under test.
            extend_attention_fwd(
                q_extend,
                k_extend,
                v_extend,
                tri_out,
                k_buffer,
                v_buffer,
                qo_indptr,
                kv_indptr,
                kv_indices,
                custom_mask,
                causal,
                mask_indptr,
                max_len_extend,
                sm_scale=sm_scale,
                logit_cap=logit_cap,
            )

            # Dense PyTorch reference, one sequence at a time.
            ref_out = torch.empty_like(tri_out)
            for i in range(B):
                start_q, end_q = int(qo_indptr[i]), int(qo_indptr[i + 1])
                start_k, end_k = int(kv_indptr[i]), int(kv_indptr[i + 1])
                prefix_len = end_k - start_k
                seq_len = end_q - start_q

                q = q_extend[start_q:end_q]  # [seq_len, H, C]
                k_prefix = k_buffer[start_k:end_k]  # [prefix_len, kv_heads, C]
                v_prefix = v_buffer[start_k:end_k]  # [prefix_len, kv_heads, Cv]
                k_ext = k_extend[start_q:end_q]  # [seq_len, kv_heads, C]
                v_ext = v_extend[start_q:end_q]  # [seq_len, kv_heads, Cv]

                scores_prefix = torch.einsum(
                    "qhc,khc->hqk", q.float(), k_prefix.float()
                )
                scores_extend = torch.einsum(
                    "qhc,khc->hqk", q.float(), k_ext.float()
                )

                if causal:
                    # Strictly-upper-triangular entries are future extend tokens.
                    causal_mask = torch.triu(
                        torch.ones(
                            (seq_len, seq_len),
                            dtype=torch.bool,
                            device=scores_extend.device,
                        ),
                        diagonal=1,
                    )
                    scores_extend = scores_extend.masked_fill(
                        causal_mask.unsqueeze(0), float("-inf")
                    )

                scores_combined = (
                    torch.cat([scores_prefix, scores_extend], dim=-1) * sm_scale
                )
                # Round probabilities to the test dtype before the PV product.
                p_combined = torch.softmax(scores_combined, dim=-1).to(dtype)
                p_prefix = p_combined[:, :, :prefix_len]
                p_extend = p_combined[:, :, prefix_len:]

                out_prefix = torch.einsum(
                    "hqk,khd->qhd", p_prefix.float(), v_prefix.float()
                )
                out_extend = torch.einsum(
                    "hqk,khd->qhd", p_extend.float(), v_ext.float()
                )
                ref_out[start_q:end_q] = out_prefix.to(dtype) + out_extend.to(dtype)

            torch.testing.assert_close(ref_out, tri_out, rtol=2e-2, atol=2e-2)
    finally:
        torch.set_default_dtype(prev_default_dtype)


@pytest.mark.parametrize("prefix_length", [512])
@pytest.mark.parametrize("extend_length", [1, 3, 8, 32])
def test_extend_attention_v2_identity_scales_match_v1(prefix_length, extend_length):
    """v2 with k_scale=v_scale=1 should match v1.

    For ``extend_length`` (== passed ``max_len_extend``) < 32 the forward uses
    ``_fwd_kernel_v2_decode``; for ``extend_length`` >= 32 it uses
    ``_fwd_kernel_v2``.
    """
    device = "cuda"
    dtype = torch.float16
    torch.manual_seed(0)
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        2, 8, prefix_length, extend_length, 128, 64, 128, dtype, device, "normal"
    )
    out_v1 = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=device,
    )
    out_v2 = torch.empty_like(out_v1)

    # v1: no v2-specific keyword arguments.
    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        out_v1,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        True,
        mask_indptr,
        max_len_extend,
        sm_scale=None,
        logit_cap=0.0,
        skip_prefix_custom_mask=True,
        config=None,
    )
    # v2: identity scales, no sliding window, no sinks.
    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        out_v2,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        True,
        mask_indptr,
        max_len_extend,
        sm_scale=None,
        logit_cap=0.0,
        skip_prefix_custom_mask=True,
        config=None,
        k_scale=1.0,
        v_scale=1.0,
        sliding_window_size=-1,
        sinks=None,
        window_kv_offsets=None,
        xai_temperature_len=-1,
    )
    torch.testing.assert_close(out_v1, out_v2, rtol=2e-2, atol=2e-2)
@pytest.mark.parametrize(
    "prefix,extend_len",
    [
        (57, 4),  # small non-power-of-two prefix with the tune-bench extend length
        (256, 4),  # align with tune decode bench (B=64, prefix=256, extend=4)
        (933, 3),  # non-power-of-two prefix + odd extend
        (512, 1),
        (127, 8),
        (1024, 16),
        (2047, 31),  # near decode path upper bound (max_len_extend < 32)
    ],
)
@pytest.mark.parametrize(
    "use_custom_mask,sliding_window_size,has_sink,kv_num_heads",
    [
        (False, 128, True, 2),
        (False, -1, False, 1),
        (True, 128, True, 2),
        (True, -1, False, 1),
    ],
)
def test_extend_attention_v2_decode_matches_v2_prefill_tuning_shapes(
    prefix,
    extend_len,
    use_custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
):
    """``_fwd_kernel_v2_decode`` vs forced ``_fwd_kernel_v2`` on shared (B, H) with varied prefix/extend.

    ``prefix`` / ``extend_len`` cover the tune decode row (256, 4) plus extra
    shapes. All cases keep ``extend_len < 32``, so the non-forced path stays on
    ``_fwd_kernel_v2_decode``. When ``use_custom_mask`` is set, a per-sequence
    causal custom mask is passed. With a positive window it is combined with
    zero ``window_kv_offsets``.
    """
    device = "cuda"
    dtype = torch.bfloat16
    torch.manual_seed(42)

    B, H = 64, 16
    kv_lora_rank, qk_rope_head_dim, v_head_dim = 128, 64, 128
    assert H % kv_num_heads == 0

    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend_len,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        "normal",
        equal_seqlens=False,
        kv_num_heads=kv_num_heads,
    )

    if use_custom_mask:
        custom_mask_t, mask_indptr = _v2_flat_causal_custom_mask_from_indptr(
            qo_indptr, kv_indptr, device
        )

    window_kv_offsets = None
    if use_custom_mask and sliding_window_size > 0:
        window_kv_offsets = torch.zeros(B, dtype=torch.int32, device=device)

    sinks = torch.randn(H, dtype=dtype, device=device) if has_sink else None

    # Both outputs start zero-filled rather than torch.empty, presumably so that
    # any rows a path leaves unwritten compare deterministically.
    out_decode = torch.zeros(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=device,
    )
    out_prefill = torch.zeros_like(out_decode)

    def run(force_prefill: bool, o_out: torch.Tensor):
        """Run the v2 forward into ``o_out``; ``force_prefill`` forces ``_fwd_kernel_v2``."""
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            o_out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask_t,
            True,
            mask_indptr,
            max_len_extend,
            sm_scale=None,
            logit_cap=0.0,
            skip_prefix_custom_mask=True,
            config=None,
            k_scale=1.0,
            v_scale=1.0,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_kv_offsets=window_kv_offsets,
            xai_temperature_len=-1,
            force_v2_prefill=force_prefill,
        )

    run(False, out_decode)
    run(True, out_prefill)

    # rtol/atol are the relative/absolute tolerances (rtol is dimensionless,
    # atol is in output units). 5e-3 is stricter than the earlier 3e-2; if
    # tiling introduces floating-point ordering differences, relax to 1e-2.
    torch.testing.assert_close(out_decode, out_prefill, rtol=5e-3, atol=5e-3)


def _build_extend_inputs_swa_style(
    B,
    H_Q,
    H_KV,
    D,
    device,
    dtype,
    max_extend_length,
    max_prefix_length=512,
):
    """Layout aligned with sglang test_triton_attention_kernels sliding-window setup.

    Prefix lengths are uniform in [1, max_prefix_length]. Extend lengths are
    uniform in [1, max_extend_length], with batch 0 fixed to max_extend_length
    so that max(b_seq_len_extend) == max_extend_length.

    Each sequence occupies ``prefix + extend`` consecutive rows of
    ``k_buffer`` / ``v_buffer``. ``kv_indices`` points at the prefix rows, and
    the extend rows are copied into ``k_extend`` / ``v_extend``.

    Returns ``(q_extend, k_extend, v_extend, k_buffer, v_buffer, qo_indptr,
    kv_indptr, kv_indices, max_len_extend)``.
    """
    b_seq_len_prefix = torch.randint(
        1, max_prefix_length + 1, (B,), dtype=torch.int32, device=device
    )
    b_seq_len_extend = torch.randint(
        1, max_extend_length + 1, (B,), dtype=torch.int32, device=device
    )
    b_seq_len_extend[0] = int(max_extend_length)
    b_seq_len = b_seq_len_prefix + b_seq_len_extend

    # Start of each sequence in the shared K/V buffer and in the extend tensors.
    b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
    kv_indices = torch.zeros(
        (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device
    )

    for i in range(B):
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
            b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i], device=device
        )

    total_token_num = torch.sum(b_seq_len).item()
    extend_token_num = torch.sum(b_seq_len_extend).item()
    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    for i in range(B):
        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
        extend_start = b_start_loc_extend[i]
        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
        k_extend[extend_start:extend_end] = k_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        v_extend[extend_start:extend_end] = v_buffer[
            extend_start_in_buffer:extend_end_in_buffer
        ]
        q_extend[extend_start:extend_end] = torch.empty(
            (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device
        ).normal_(mean=0.1, std=0.2)

    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

    qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
    qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

    return (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        max_len_extend,
    )


@pytest.mark.parametrize("window_size", [-1, 32, 127])
@pytest.mark.parametrize("max_extend_length", [1, 3, 8, 32, 256])
def test_extend_attention_v2_sliding_window(window_size, max_extend_length):
    """v2 + sliding_window_size vs torch reference (sglang-style construction).

    ``k_scale = v_scale = 1.2`` also exercises the prefix-only scaling path.
    """
    torch.manual_seed(42)
    device = "cuda"
    dtype = torch.bfloat16
    B, H_Q, H_KV, D = 64, 8, 1, 128

    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        max_len_extend,
    ) = _build_extend_inputs_swa_style(
        B, H_Q, H_KV, D, device, dtype, max_extend_length
    )
    assert max_len_extend == max_extend_length

    extend_token_num = q_extend.shape[0]
    o_triton = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    o_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask=None,
        is_causal=True,
        mask_indptr=None,
        max_len_extend=max_len_extend,
        sm_scale=None,
        logit_cap=0.0,
        skip_prefix_custom_mask=True,
        config=None,
        k_scale=1.2,
        v_scale=1.2,
        sliding_window_size=window_size,
        sinks=None,
        window_kv_offsets=None,
        xai_temperature_len=-1,
    )

    extend_attention_fwd_torch_swa(
        q_extend,
        k_extend,
        v_extend,
        o_torch,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        window_size,
        k_scale=1.2,
        v_scale=1.2,
        sm_scale=None,
    )

    torch.testing.assert_close(o_triton, o_torch, rtol=2e-2, atol=2e-2)


if __name__ == "__main__":
    test_op_fwd(1, 2, 1024, 1024, 256, 0, 256, torch.bfloat16, "normal", False)
    test_op_fwd(3, 5, 110, 333, 18, 0, 17, torch.float32, "normal", True)
    test_op_fwd(4, 16, 1024, 16, 512, 64, 512, torch.float16, "normal", True)