# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
from typing import Optional

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens,
                                  prepare_token_indices)
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16]],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    bos, eos = i_b * T, i_b * T + T

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H * S + i_h * S

    NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    # online-softmax state: running max `b_m` and running normalizer `b_acc`
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        if i_s <= i_t:
            p_k = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = tl.exp(b_mp - b_m)
            # [G, BS]
            b_p = tl.exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

            b_mp = b_m
    b_o = b_o / b_acc[:, None]
    b_m += tl.log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))


def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.Tensor,
    block_size: int,
    scale: float,
):
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    # larger tiles on Hopper (SM 9.0+), smaller ones otherwise
    if torch.cuda.get_device_capability()[0] >= 9:
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "The key dimension can not be larger than 256"

    # one program per (value tile, token, batch * kv-head)
    grid = (NV, T, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float32, device=q.device)

    print("grid", grid)
    parallel_nsa_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        H=H,
        HQ=HQ,
        G=G,
        T=T,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse


class ParallelNSAFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
        ctx.dtype = q.dtype
        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively,
        # and `token_indices` will be [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        # NOTE: `token_indices` is not consumed by the launcher below;
        # the variable-length path is not wired into this kernel.
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_fwd(
            q=q,
            k=k,
            v=v,
            block_indices=block_indices,
            block_size=block_size,
            scale=scale)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.block_indices = block_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)


def parallel_nsa(q: torch.Tensor,
                 k: torch.Tensor,
                 v: torch.Tensor,
                 block_indices: torch.LongTensor,
                 block_size: int = 64,
                 scale: Optional[float] = None,
                 cu_seqlens: Optional[torch.LongTensor] = None,
                 head_first: bool = False) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >= 16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_size (int):
            Selected block size. Default: 64.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
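
    Examples:
        A minimal usage sketch (shapes chosen purely for illustration). Every token nominally
        selects blocks `0..S-1` here; blocks starting after the token are skipped by the kernel,
        so the call reduces to dense causal attention:

        >>> import torch
        >>> B, T, H, HQ, K, V, S, block_size = 1, 64, 1, 16, 32, 32, 4, 16
        >>> q = torch.randn(B, T, HQ, K, dtype=torch.float16, device='cuda')
        >>> k = torch.randn(B, T, H, K, dtype=torch.float16, device='cuda')
        >>> v = torch.randn(B, T, H, V, dtype=torch.float16, device='cuda')
        >>> block_indices = torch.arange(S, device='cuda').expand(B, T, H, S).contiguous()
        >>> o = parallel_nsa(q, k, v, block_indices, block_size=block_size)
        >>> o.shape
        torch.Size([1, 64, 16, 32])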
""" if scale is None: scale = k.shape[-1]**-0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, block_indices)) o = ParallelNSAFunction.apply(q, k, v, block_indices, block_size, scale, cu_seqlens) if head_first: o = rearrange(o, 'b t h d -> b h t d') return o if __name__ == "__main__": B, T, H, HQ, D, S, block_size, dtype, scale = 1, 64, 1, 16, 32, 1, 64, torch.float16, 0.1 q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True) do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda') block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda') for b in range(B): for t in range(T): for h in range(H): i_i = torch.randperm(max(1, (t // block_size)))[:S] block_indices[b, t, h, :len(i_i)] = i_i block_indices = block_indices.sort(-1)[0] block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda') ref = naive_nsa( q=q, k=k, v=v, block_indices=block_indices, block_counts=block_counts, block_size=block_size, scale=scale) # print(ref) tri = parallel_nsa( q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale) # print(tri) torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2) # import flash_attn # # gqa # o_gqa = flash_attn.flash_attn_func( # q, # k, # v, # softmax_scale=scale, # ) # print(o_gqa) # torch.testing.assert_close(o_gqa, tri, atol=1e-2, rtol=1e-2)