Commit cf6e11c9 authored by qisan

feat: merge dcu branch features

parents 3f27f85a d0436b7b
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = T.float16
q_dtype = T.float8_e4m3fn
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
qKV_shared = T.alloc_shared([block_N, dim], q_dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.disable_warp_group_reg_alloc()
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared)
T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.copy(qKV_shared, KV_shared)
T.clear(acc_s)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
return main_no_split
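# Note on the kernel above: the online softmax keeps a running (max, sum) pair per
# query row and works in base 2, folding log2(e) into `scale` so that exp(x) becomes
# exp2(x * log2(e)), which maps to a fast hardware instruction on most GPUs.
# A minimal PyTorch sketch of one update step (names are illustrative only):
#
#   m_new = torch.maximum(m_prev, scores.amax(-1))             # scores_max
#   alpha = torch.exp2(m_prev * scale - m_new * scale)         # scores_scale
#   p = torch.exp2(scores * scale - m_new[..., None] * scale)  # acc_s
#   l_new = l_prev * alpha + p.sum(-1)                         # logsum
#   acc = acc * alpha[..., None] + p.to(v_block.dtype) @ v_block  # acc_o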
def ref_program(q, q_pe, kv, k_pe):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim) ** 0.5
    q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch, num_head_groups, kv_head_num, dim]
    q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch, num_head_groups, kv_head_num, pe_dim]
    kv = rearrange(kv, "b n h d -> b h n d")  # [batch, kv_head_num, seqlen_kv, dim]
    k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch, kv_head_num, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
    scores = einsum(query, key, "b g h d, b h s d -> b g h s")  # [batch, num_head_groups, kv_head_num, seqlen_kv]
    attention = F.softmax(scores / scale, dim=-1)  # [batch, num_head_groups, kv_head_num, seqlen_kv]
    out = einsum(attention, kv, "b g h s, b h s d -> b g h d")  # [batch, num_head_groups, kv_head_num, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
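# Correctness sketch: with `kernel` as constructed in the __main__ block below and
# inputs in the dtypes declared by main_no_split (fp16 Q/Q_pe/K_pe, fp8 KV on CUDA),
# the kernel output can be checked against this reference, e.g.
#
#   out = kernel(q, q_pe, kv, k_pe)
#   ref = ref_program(q.float(), q_pe.float(), kv.float(), k_pe.float())
#   torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)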
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
import tilelang.testing
import example_mla_decode
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mla_decode():
example_mla_decode.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
num_split = 1
def flash_split_ref(Q, Q_pe, KV, K_pe):
dim = Q.shape[-1]
pe_dim = Q_pe.shape[-1]
batch = Q.size(0)
nheads = Q.size(1)
block_N = 64
seqlen_kv = KV.size(1)
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale
Q_pe_ = Q_pe * scale
KV_ = KV.expand(-1, -1, nheads, -1)
K_pe_ = K_pe.expand(-1, -1, nheads, -1)
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum(
"bhd,bkhd->bhk",
Q_,
KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
acc_s += torch.einsum(
"bhd,bkhd->bhk",
Q_pe_,
K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
acc_o *= scores_scale[:, :, None]
acc_s = torch.exp2(acc_s - scores_max[:, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum(
"bhk,bkhd->bhd",
acc_s_cast,
KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, None]
logsum = torch.log2(logsum) + scores_max
gacc_o[ks, :, :, :] = acc_o
glogsum[ks, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)
def reduce_ref(Q, Q_pe, KV, K_pe, glse, Output_partial):
o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)
lse_max = glse.max(dim=2, keepdim=False).values
for ks in range(num_split):
lse = glse[:, :, ks]
lse_logsum += torch.exp2(lse - lse_max)
lse_logsum = torch.log2(lse_logsum) + lse_max
for ks in range(num_split):
lse = glse[:, :, ks]
scale = torch.exp2(lse - lse_logsum)
o += Output_partial[:, :, ks, :] * scale[:, :, None]
return o.to(torch.float16)
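# How the two halves fit together (sketch; num_split == 1 here): flash_split_ref
# returns the per-split log-sum-exp and partial outputs, and reduce_ref recombines
# them into the final attention output:
#
#   glse, o_partial = flash_split_ref(Q, Q_pe, KV, K_pe)
#   out = reduce_ref(Q, Q_pe, KV, K_pe, glse, o_partial)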
# ruff: noqa
import torch
import time
import argparse
import tilelang
from tilelang import language as T
import tilelang.testing
from typing import Optional, Union
from einops import rearrange, repeat
import triton
import triton.language as tl
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o_slc: torch.Tensor,
o_swa: Optional[torch.Tensor],
lse_slc: torch.Tensor,
lse_swa: Optional[torch.Tensor],
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"
grid = (T, NV, B * H)
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
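# Launch sketch (assumes fp16 CUDA inputs; B, T, HQ, V as unpacked above, and the
# caller owns the output buffers, mirroring the benchmark code further below):
#
#   o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
#   lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
#   o_slc, lse_slc, _, _ = parallel_nsa_fwd(
#       q, k, v, o_slc, None, lse_slc, None,
#       block_indices, block_counts, block_size, window_size=0, scale=scale)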
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (Union[torch.LongTensor, int]):
            Number of selected blocks for each token.
            If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
            each token can select a different number of blocks.
            If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
        scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
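# Usage sketch (shapes follow the docstring above; the values are illustrative and
# `generate_block_indices` is the helper defined further below):
#
#   B, T, H, HQ, D, S = 1, 1024, 4, 64, 64, 16
#   q = torch.randn(B, T, HQ, D, dtype=torch.float16, device="cuda")
#   k = torch.randn(B, T, H, D, dtype=torch.float16, device="cuda")
#   v = torch.randn(B, T, H, D, dtype=torch.float16, device="cuda")
#   g_slc = torch.ones(B, T, HQ, dtype=torch.float16, device="cuda")
#   g_swa = torch.zeros(B, T, HQ, dtype=torch.float16, device="cuda")
#   block_indices = generate_block_indices(B, T, H, S, 64)
#   o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_size=64)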
def naive_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (Union[torch.LongTensor, int]):
            Number of selected blocks for each token.
            If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
            each token can select a different number of blocks.
            If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
        scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices)
)
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = (
torch.einsum("h d, n h d -> n h", q_i, k_i_slc)
.masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf"))
.softmax(0)
)
if not varlen:
o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b))
attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
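# naive_nsa above is the ground truth used by the validation paths below, e.g.
#
#   ref = naive_nsa(q, k, v, g_slc, g_swa, block_indices, block_size=64)
#   torch.testing.assert_close(ref, o, rtol=1e-2, atol=1e-2)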
def get_configs():
import itertools
iter_params = dict(
block_T=[128, 256, 512],
num_stages=[0, 1, 2, 4, 5],
threads=[32, 64, 128, 256, 512],
)
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
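# get_configs() enumerates the full Cartesian product of the parameters above
# (3 * 5 * 5 = 75 candidate configurations) for tilelang.autotune to sweep.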
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def tilelang_sparse_attention(
batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32
):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(block_T, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The key dimension can not be larger than 256"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
@T.prim_func
def tilelang_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_shared([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return tilelang_sparse_attention
def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size):
"""Generate random block indices for the benchmark."""
block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda")
for b in range(batch):
for t in range(seq_len):
for h in range(heads):
i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks]
block_indices[b, t, h, : len(i_i)] = i_i
return block_indices.sort(-1)[0]
def benchmark_nsa(
batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False
):
"""Benchmark the TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
tilelang.testing.set_random_seed(0)
torch.random.manual_seed(0)
# Compile the NSA kernel
kernel = tilelang_sparse_attention(
batch=batch_size,
heads=head_query,
seq_len=seq_len,
dim=dim,
is_causal=True,
block_size=block_size,
groups=head_query // heads,
selected_blocks=selected_blocks,
scale=scale,
)
profiler = kernel.get_profiler()
profiler_latency = profiler.do_bench()
print(f"Profiler latency: {profiler_latency} ms")
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32)
# Warmup
for _ in range(warmup):
kernel(Q, K, V, block_indices, out)
# Synchronize before timing
torch.cuda.synchronize()
# Benchmark
start_time = time.time()
for _ in range(iterations):
kernel(Q, K, V, block_indices, out)
torch.cuda.synchronize()
end_time = time.time()
# Calculate metrics
elapsed_time = end_time - start_time
avg_time = elapsed_time / iterations * 1000 # ms
# Calculate FLOPs (approximate for NSA)
# Each token attends to selected_blocks * block_size tokens
# Each attention calculation involves 2*dim FLOPs for QK
# And another 2*dim FLOPs for attention * V
flops_per_token = 4 * dim * selected_blocks * block_size
total_flops = batch_size * seq_len * head_query * flops_per_token
flops_per_sec = total_flops / (elapsed_time / iterations)
tflops = flops_per_sec / 1e12
# Validate result against reference if requested
if validate:
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda")
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
            block_indices=block_indices.long(),  # gather in naive_nsa requires int64 indices
            block_counts=None,  # the TileLang kernel always uses all selected blocks
block_size=block_size,
scale=scale,
)
is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2)
if is_valid:
print("Validation: PASSED")
else:
print("Validation: FAILED")
print(f"Max difference: {(ref - out).abs().max().item()}")
# Return benchmark results
return {
"avg_time_ms": avg_time,
"tflops": tflops,
"batch_size": batch_size,
"seq_len": seq_len,
"heads": heads,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size,
}
def benchmark_triton_nsa(
batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False
):
"""Benchmark the Triton-based TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
tilelang.testing.set_random_seed(0)
torch.random.manual_seed(0)
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda")
o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda")
# Warmup
for _ in range(warmup):
out = parallel_nsa_fwd(
q=Q,
k=K,
v=V,
o_slc=o_slc,
o_swa=None,
lse_slc=lse_slc,
lse_swa=None,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale,
)
# Synchronize before timing
torch.cuda.synchronize()
# Benchmark
start_time = time.time()
for _ in range(iterations):
out = parallel_nsa_fwd(
q=Q,
k=K,
v=V,
o_slc=o_slc,
o_swa=None,
lse_slc=lse_slc,
lse_swa=None,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale,
)
torch.cuda.synchronize()
end_time = time.time()
# Calculate metrics
elapsed_time = end_time - start_time
avg_time = elapsed_time / iterations * 1000 # ms
# Calculate FLOPs (approximate for NSA)
flops_per_token = 4 * dim * selected_blocks * block_size
total_flops = batch_size * seq_len * head_query * flops_per_token
flops_per_sec = total_flops / (elapsed_time / iterations)
tflops = flops_per_sec / 1e12
# Validate result against reference if requested
if validate:
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
            block_counts=None,  # the Triton kernel in this file always uses all S blocks
block_size=block_size,
scale=scale,
)
        is_valid = torch.allclose(ref, o_slc, atol=1e-2, rtol=1e-2)
        if is_valid:
            print("Validation: PASSED")
        else:
            print("Validation: FAILED")
            print(f"Max difference: {(ref - o_slc).abs().max().item()}")
# Return benchmark results
return {
"avg_time_ms": avg_time,
"tflops": tflops,
"batch_size": batch_size,
"seq_len": seq_len,
"heads": heads,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size,
}
def run_benchmark_suite(impl="all"):
"""Run a suite of benchmarks with different configurations."""
# Define configurations to benchmark
configs = [
# Small model config - Note: head_query must be a multiple of heads*16 for Triton
{"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32},
# Medium model config
{"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64},
# Large model config
{"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128},
]
results = []
for config in configs:
print(f"Running benchmark with config: {config}")
if impl in ["all", "tilelang"]:
print("Benchmarking TileLang implementation:")
result = benchmark_nsa(
batch_size=config["batch_size"],
seq_len=config["seq_len"],
heads=config["heads"],
head_query=config["head_query"],
dim=config["dim"],
selected_blocks=config["selected_blocks"],
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False,
)
results.append({"impl": "tilelang", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ["all", "triton"]:
print("Benchmarking Triton implementation:")
result = benchmark_triton_nsa(
batch_size=config["batch_size"],
seq_len=config["seq_len"],
heads=config["heads"],
head_query=config["head_query"],
dim=config["dim"],
selected_blocks=config["selected_blocks"],
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False,
)
results.append({"impl": "triton", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ["all"]:
# Print comparison if both implementations were run
tilelang_result = next(
r
for r in results
if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]
)
triton_result = next(
r
for r in results
if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]
)
speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"]
print(f"Speedup (Triton vs TileLang): {speedup:.2f}x")
print("-" * 50)
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark TileLang Sparse Attention")
parser.add_argument("--batch", type=int, default=32, help="Batch size")
parser.add_argument("--seq_len", type=int, default=1024, help="Sequence length")
parser.add_argument("--heads", type=int, default=1, help="Number of heads")
parser.add_argument("--head_query", type=int, default=16, help="Number of query heads")
parser.add_argument("--dim", type=int, default=128, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=32, help="Block size")
parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)")
parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
parser.add_argument("--validate", action="store_true", help="Validate against reference")
parser.add_argument("--suite", action="store_true", help="Run benchmark suite")
parser.add_argument(
"--impl",
type=str,
default="all",
choices=["tilelang", "triton", "all"],
help="Implementation to benchmark (tilelang, triton, or all)",
)
args = parser.parse_args()
# For Triton impl, ensure head_query is a multiple of heads*16
if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0:
# Adjust head_query to nearest valid value
args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16)
print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation")
if args.suite:
run_benchmark_suite(impl=args.impl)
else:
        dtype = torch.float16 if args.dtype == "float16" else torch.float32
if args.impl in ["tilelang", "all"]:
print("Benchmarking TileLang implementation:")
result = benchmark_nsa(
batch_size=args.batch,
seq_len=args.seq_len,
heads=args.heads,
head_query=args.head_query,
dim=args.dim,
selected_blocks=args.selected_blocks,
block_size=args.block_size,
dtype=dtype,
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate,
)
print("\nBenchmark Results (TileLang):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, "
+ f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, "
+ f"block_size={args.block_size}"
)
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if args.impl in ["triton", "all"]:
print("Benchmarking Triton implementation:")
result = benchmark_triton_nsa(
batch_size=args.batch,
seq_len=args.seq_len,
heads=args.heads,
head_query=args.head_query,
dim=args.dim,
selected_blocks=args.selected_blocks,
block_size=args.block_size,
dtype=dtype,
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate,
)
print("\nBenchmark Results (Triton):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, "
+ f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, "
+ f"block_size={args.block_size}"
)
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import triton
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
import tilelang
from tilelang import language as T
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def tilelang_kernel_fwd(
batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=1,
selected_blocks=16,
):
from tilelang import language as T
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
o_slc_shape = [batch, seq_len, heads, dim]
lse_slc_shape = [batch, seq_len, heads]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The key dimension can not be larger than 256"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 0
threads = 32
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
O_slc: T.Tensor(o_slc_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for k, j in T.Parallel(G, BS):
acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for k in T.Parallel(G):
scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale)
for k, j in T.Parallel(G, BS):
acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for k in T.Parallel(G):
logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k]
T.copy(acc_s, acc_s_cast)
# Rescale
for k, j in T.Parallel(G, BV):
acc_o[k, j] *= scores_scale[k]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(
O_shared,
O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV],
)
for i in T.Parallel(G):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G])
return native_sparse_attention
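# Note: besides O_slc, the forward kernel stores a base-2 log-sum-exp per query row,
# LSE = log2(sum_j 2^(s_j * scale)), so the backward kernels below can recompute the
# softmax probabilities directly as exp2(qk * scale - LSE) without a second pass over
# the row maximum.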
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def tilelang_kernel_bwd_dkv(
batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=1,
selected_blocks=16,
dtype=T.float16,
accum_dtype=T.float32,
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
else:
sm_scale = scale
scale = sm_scale * 1.44269504
from tilelang import language as T
B = batch
BS = block_size
G = groups
V = dim
K = dim
BK = tilelang.next_power_of_2(K)
BV = min(128, tilelang.next_power_of_2(dim))
NS = tilelang.cdiv(seq_len, BS)
NV = tilelang.cdiv(V, BV)
heads_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
k_shape = [batch, seq_len, heads_kv, dim]
v_shape = [batch, seq_len, heads_kv, dim]
lse_slc_shape = [batch, seq_len, heads]
delta_slc_shape = [batch, seq_len, heads]
o_shape = [batch, heads, seq_len, dim]
do_slc_shape = [batch, seq_len, heads, dim]
dk_shape = [NV, batch, seq_len, heads_kv, dim]
dv_shape = [batch, seq_len, heads_kv, dim]
block_mask_shape = [batch, seq_len, heads_kv, NS]
    H = heads_kv
    num_threads = 32
    print("NV", NV, "NS", NS, "B", B, "H", H)
@T.prim_func
def flash_bwd_dkv(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(k_shape, dtype),
V: T.Tensor(v_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
DO_slc: T.Tensor(do_slc_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, T.int32),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
Q_shared = T.alloc_shared([G, BK], dtype)
qkT = T.alloc_fragment([BS, G], accum_dtype)
qkT_cast = T.alloc_fragment([BS, G], dtype)
dsT = T.alloc_fragment([BS, G], accum_dtype)
dsT_cast = T.alloc_fragment([BS, G], dtype)
lse_shared = T.alloc_shared([G], accum_dtype)
delta = T.alloc_shared([G], accum_dtype)
do = T.alloc_shared([G, BV], dtype)
dv = T.alloc_fragment([BS, BV], accum_dtype)
dk = T.alloc_fragment([BS, BK], accum_dtype)
dq = T.alloc_fragment([BS, G], accum_dtype)
dv_shared = T.alloc_shared([BS, BV], dtype)
dk_shared = T.alloc_shared([BS, BK], dtype)
i_b, i_h = i_bh // H, i_bh % H
T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)
# [BS, BK]
T.clear(dk)
# [BS, BV]
T.clear(dv)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
loop_st = i_s * BS
loop_ed = seq_len
for i in T.Pipelined(
start=loop_st,
stop=loop_ed,
num_stages=0,
):
b_m_slc = BlockMask[i_b, i, i_h, i_s]
if b_m_slc != 0:
# [G, BK]
T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.clear(qkT)
# [BS, BK] @ [G, BK] -> [BS, G]
T.gemm(
K_shared,
Q_shared,
qkT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
# [G]
T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)
# [G, BV]
T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
T.clear(dsT)
# [BS, BV] @ [G, BV] -> [BS, G]
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(qkT, qkT_cast)
# [BS, G] @ [G, BV] -> [BS, BV]
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
# [G]
T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
for i, j in T.Parallel(BS, G):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
# [BS, G] @ [G, BK] -> [BS, BK]
T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])
return flash_bwd_dkv
def make_dq_layout(dQ):
from tilelang import language as T
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(
dQ.shape,
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2],
)
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def tilelang_kernel_bwd_dqkv(
batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=1,
selected_blocks=16,
dtype=T.float16,
accum_dtype=T.float32,
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
else:
sm_scale = scale
scale = sm_scale * 1.44269504
from tilelang import language as T
B = batch
BS = block_size
G = groups
V = dim
K = dim
BK = tilelang.next_power_of_2(K)
BV = min(128, tilelang.next_power_of_2(dim))
NS = tilelang.cdiv(seq_len, BS)
NV = tilelang.cdiv(V, BV)
heads_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
k_shape = [batch, seq_len, heads_kv, dim]
v_shape = [batch, seq_len, heads_kv, dim]
lse_slc_shape = [batch, seq_len, heads]
delta_slc_shape = [batch, seq_len, heads]
o_shape = [batch, heads, seq_len, dim]
do_slc_shape = [batch, seq_len, heads, dim]
dq_shape = [NV, batch, seq_len, heads, dim]
dk_shape = [NV, batch, seq_len, heads_kv, dim]
dv_shape = [batch, seq_len, heads_kv, dim]
block_mask_shape = [batch, seq_len, heads_kv, NS]
    H = heads_kv
    num_threads = 32
@T.prim_func
def flash_bwd_dqkv(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(k_shape, dtype),
V: T.Tensor(v_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
DO_slc: T.Tensor(do_slc_shape, dtype),
DQ: T.Tensor(dq_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, T.int32),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
dsT_shared = T.alloc_shared([BS, G], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
Q_shared = T.alloc_shared([G, BK], dtype)
qkT = T.alloc_fragment([BS, G], accum_dtype)
qkT_cast = T.alloc_fragment([BS, G], dtype)
dsT = T.alloc_fragment([BS, G], accum_dtype)
dsT_cast = T.alloc_fragment([BS, G], dtype)
lse_shared = T.alloc_shared([G], accum_dtype)
delta = T.alloc_shared([G], accum_dtype)
do = T.alloc_shared([G, BV], dtype)
dv = T.alloc_fragment([BS, BV], accum_dtype)
dk = T.alloc_fragment([BS, BK], accum_dtype)
dq = T.alloc_fragment([G, BK], accum_dtype)
dv_shared = T.alloc_shared([BS, BV], dtype)
dk_shared = T.alloc_shared([BS, BK], dtype)
i_b, i_h = i_bh // H, i_bh % H
T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)
# [BS, BK]
T.clear(dk)
# [BS, BV]
T.clear(dv)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
loop_st = i_s * BS
loop_ed = seq_len
for i in T.Pipelined(
start=loop_st,
stop=loop_ed,
num_stages=0,
):
b_m_slc = BlockMask[i_b, i, i_h, i_s]
if b_m_slc != 0:
# [G, BK]
T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.clear(qkT)
# [BS, BK] @ [G, BK] -> [BS, G]
T.gemm(
K_shared,
Q_shared,
qkT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
# [G]
T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)
# [G, BV]
T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
T.clear(dsT)
# [BS, BV] @ [G, BV] -> [BS, G]
T.gemm(
V_shared,
do,
dsT,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(qkT, qkT_cast)
# [BS, G] @ [G, BV] -> [BS, BV]
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
# [G]
T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
for _i, _j in T.Parallel(BS, G):
dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale
# [BS, G] @ [G, BK] -> [BS, BK]
T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
# [BS, G] * [BS, BK] -> [G, BK]
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for _i, _j in T.Parallel(G, BK):
T.atomic_add(DQ[i_v, i_b, i, i_h * G + _i, _j], dq[_i, _j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])
return flash_bwd_dqkv
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_kernel_preprocess(
batch,
heads,
seq_len,
dim,
dtype=T.float16,
accum_dtype=T.float32,
blk=32,
):
from tilelang import language as T
shape = [batch, seq_len, heads, dim]
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx])
return flash_bwd_prep
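# flash_bwd_prep computes the FlashAttention-style correction term
# Delta[b, t, h] = sum_d O[b, t, h, d] * dO[b, t, h, d]; an equivalent PyTorch
# one-liner (sketch) would be:
#
#   delta = (o.float() * do.float()).sum(-1)  # [batch, seq_len, heads]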
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_kernel_block_mask(
batch,
heads,
seq_len,
selected_blocks,
block_size,
dtype=T.int32,
):
from tilelang import language as T
block_indices_shape = [batch, seq_len, heads, selected_blocks]
block_counts_shape = [batch, seq_len, heads]
S = selected_blocks
BS = block_size
NS = tilelang.cdiv(seq_len, BS)
block_mask_shape = [batch, seq_len, heads, NS]
    USE_BLOCK_COUNTS = True  # BlockCounts is always passed by the backward path below
@T.prim_func
def flash_bwd_block_mask(
BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore
BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore
BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore
):
with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz):
i_t, i_b, i_hs = bx, by, bz
i_h, i_s = i_hs // S, i_hs % S
b_i = BlockIndices[i_b, i_t, i_h, i_s]
if USE_BLOCK_COUNTS:
b_m = b_i * BS <= i_t and i_s < BlockCounts[i_b, i_t, i_h].astype(i_s.dtype)
BlockMask[i_b, i_t, i_h, i_s] = b_m
else:
b_m = b_i * BS <= i_t
BlockMask[i_b, i_t, i_h, i_s] = b_m
return flash_bwd_block_mask
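# BlockMask[b, t, h, s] is nonzero iff the s-th selected block for query t starts at
# or before t (causal) and, when block counts are used, s is within that token's
# count. A PyTorch sketch of the same predicate:
#
#   t_idx = torch.arange(T, device=block_indices.device)
#   mask = (block_indices * BS <= t_idx[None, :, None, None]) & (
#       torch.arange(S, device=block_indices.device) < block_counts[..., None])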
def parallel_nsa_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o_slc: torch.Tensor,
lse_slc: torch.Tensor,
do_slc: torch.Tensor,
o_swa: torch.Tensor,
lse_swa: torch.Tensor,
do_swa: torch.Tensor,
block_indices: torch.Tensor,
block_counts: Union[torch.LongTensor, int],
block_size: int = 64,
window_size: int = 0,
scale: float = None,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
BK = triton.next_power_of_2(K)
BV = min(128, triton.next_power_of_2(v.shape[-1]))
NV = triton.cdiv(V, BV)
assert window_size == 0, "Window size is not supported yet"
delta_slc = tilelang_kernel_preprocess(B, HQ, T, K)(o_slc, do_slc)
dq = torch.zeros(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device)
dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)
block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool)
fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv(
batch=B,
heads=HQ,
seq_len=T,
dim=K,
is_causal=True,
block_size=BS,
groups=G,
selected_blocks=S,
scale=scale,
)
fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32))
dq = dq.sum(0)
dk = dk.sum(0)
return dq, dk, dv
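# In parallel_nsa_bwd, dq is zero-initialized and accumulated with atomic adds inside
# the kernel (every selected KV block a query attends to contributes to it), while
# dk/dv are written once per KV block; the leading NV split dimension is then reduced
# with sum(0) so callers get plain q/k-shaped gradients.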
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(
ctx,
q,
k,
v,
block_indices,
block_counts,
block_size,
window_size,
scale,
offsets,
):
ctx.dtype = q.dtype
assert offsets is None, "Offsets are not supported yet"
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
B, SEQLEN, HQ, D = q.shape
H = k.shape[2]
G = HQ // H
S = block_indices.shape[-1]
V = v.shape[-1]
kernel = tilelang_kernel_fwd(
batch=B,
heads=HQ,
seq_len=SEQLEN,
dim=D,
is_causal=True,
scale=scale,
block_size=block_size,
groups=G,
selected_blocks=S,
)
o_slc = torch.empty(B, SEQLEN, HQ, D, dtype=v.dtype, device=q.device)
lse_slc = torch.empty(B, SEQLEN, HQ, dtype=torch.float, device=q.device)
kernel(q, k, v, block_indices.to(torch.int32), o_slc, lse_slc)
ctx.save_for_backward(q, k, v, o_slc, lse_slc)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), lse_slc.to(torch.float)
@staticmethod
@contiguous
@autocast_custom_bwd
def backward(ctx, do_slc, do_swa):
q, k, v, o_slc, lse_slc = ctx.saved_tensors
dq, dk, dv = parallel_nsa_bwd(
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=None,
lse_slc=lse_slc,
lse_swa=None,
do_slc=do_slc,
do_swa=do_swa,
block_indices=ctx.block_indices,
block_counts=ctx.block_counts,
block_size=ctx.block_size,
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices,
)
return (
dq.to(q),
dk.to(k),
dv.to(v),
None,
None,
None,
None,
None,
None,
None,
None,
)
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, SEQLEN, HQ, K]` if `head_first=False` else `[B, HQ, SEQLEN, K]`.
k (torch.Tensor):
keys of shape `[B, SEQLEN, H, K]` if `head_first=False` else `[B, H, SEQLEN, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, SEQLEN, H, V]` if `head_first=False` else `[B, H, SEQLEN, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, SEQLEN, HQ]` if `head_first=False` else `[B, HQ, SEQLEN]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, SEQLEN, H, S]` if `head_first=False` else `[B, H, SEQLEN, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, H, SEQLEN]` (`head_first=True`) or `[B, SEQLEN, H]` (`head_first=False`) is provided,
each token can select a different number of blocks; an integer means every token selects that many blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 1, 32, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
ref.backward(do)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
block_counts=block_counts,
)
tri.backward(do)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
# assert_close(" o", ref, tri, 0.004)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from reference import naive_nsa_simple_inference
import tilelang
from tilelang import language as T
import tilelang.testing
tilelang.testing.set_random_seed(42)
# TODO(lei): workaround, as threads is not divisible by warp group size,
# auto warp specialization may have some bugs.
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def native_sparse_attention(
batch,
heads,
seq_len, # Length of K/V sequences (context window size)
dim, # Embedding dimension per head
scale=None,
block_size=64, # Tile size for attention computation
groups=1, # Grouped query attention (GQA) groups
selected_blocks=16, # Number of blocks to select per attention head
):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
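# Folding log2(e) into the scale lets the kernel use exp2 in place of exp: exp2(x * log2(e)) == exp(x),
# and exp2 maps to a cheaper hardware instruction.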
head_kv = heads // groups
# Modified shapes for inference (q has seq_len=1)
q_shape = [batch, 1, heads, dim] # Changed seq_len to 1
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The head dimension cannot be larger than 128"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 0
threads = 32
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim]
K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim]
V: T.Tensor(kv_shape, dtype), # Same shape as K
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices
Output: T.Tensor(q_shape, dtype), # Output attention tensor
):
with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
# Shared memory allocations for tile storage
Q_shared = T.alloc_shared([G, BK], dtype) # Current query block
K_shared = T.alloc_shared([BS, BK], dtype) # Current key block
V_shared = T.alloc_shared([BS, BV], dtype) # Current value block
O_shared = T.alloc_shared([G, BV], dtype) # Output accumulator
# Attention computation buffers
acc_s = T.alloc_fragment([G, BS], accum_dtype) # QK^T scores
acc_s_cast = T.alloc_fragment([G, BS], dtype) # Casted scores for softmax
acc_o = T.alloc_fragment([G, BV], accum_dtype) # Output accumulator
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_v, i_bh = by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
# Copy Q for the single position
T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# Main attention computation loop over selected blocks
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset
if i_s >= 0: # Skip invalid/padding blocks
# Load current key block to shared memory
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
# Compute QK^T attention scores
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Online softmax with numerical stability
# 1. Compute max for scaling
# 2. Compute exponentials and sum
# 3. Maintain running logsum for normalization
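# Per-block update in exp2 form (a sketch of the math, not extra code):
#   r      = exp2(m_prev * scale - m_new * scale)
#   p[i,j] = exp2(s[i,j] * scale - m_new * scale)
#   l_new  = l_prev * r + sum_j p[i,j]
#   o_new  = o_prev * r + p @ V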
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale the running output by the max shift before accumulating this block
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# Accumulate attention-weighted values
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Final normalization and output
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i] # Normalize by logsum
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0
return native_sparse_attention
def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
groups = HQ // H
SEQ_LEN_Q = 1
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
dim=D,
block_size=block_size,
groups=HQ // H,
selected_blocks=S,
)
Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda")
DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(SEQ_LEN_Q):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda")
out = kernel(Q, K, V, block_indices.to(torch.int32))
ref = naive_nsa_simple_inference(
q=Q,
k=K,
v=V,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main()
# ruff: noqa
import torch
from reference import naive_nsa
import tilelang
from tilelang import language as T
import tilelang.testing
tilelang.testing.set_random_seed(0)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The head dimension cannot be larger than 128"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 2
threads = 32
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return native_sparse_attention
def main():
B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
kernel = native_sparse_attention(
batch=B,
heads=HQ,
seq_len=SEQ_LEN,
dim=D,
is_causal=True,
block_size=block_size,
groups=HQ // H,
selected_blocks=S,
scale=scale,
)
print(kernel.get_kernel_source())
torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda")
for b in range(B):
for t in range(SEQ_LEN):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item()
block_indices = block_indices.sort(-1)[0]
out = kernel(Q, K, V, block_indices.to(torch.int32))
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
scale=scale,
)
print("out", out)
print("ref", ref)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main()
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import tilelang
from tilelang import language as T
import tilelang.testing
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from reference import naive_nsa
from einops import rearrange
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [c_seq_len, heads, dim]
kv_shape = [c_seq_len, head_kv, dim]
o_slc_shape = [c_seq_len, heads, dim]
o_swa_shape = [c_seq_len, heads, dim]
lse_slc_shape = [c_seq_len, heads]
lse_swa_shape = [c_seq_len, heads]
block_indices_shape = [c_seq_len, head_kv, selected_blocks]
block_counts_shape = [c_seq_len, head_kv]
offsets_shape = [batch + 1]
token_indices_shape = [c_seq_len, 2]
block_indices_dtype = T.int32
block_counts_dtype = T.int32
offsets_dtype = T.int32
token_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The head dimension cannot be larger than 128"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 0
threads = 32
@T.prim_func
def native_sparse_attention_varlen(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
O_slc: T.Tensor(o_slc_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype),
Offsets: T.Tensor(offsets_shape, offsets_dtype),
TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype),
):
with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_c, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
i_n, i_t = TokenIndices[i_c, 0], TokenIndices[i_c, 1]
bos = Offsets[i_n]
eos = Offsets[i_n + 1]
current_seq_len = eos - bos
NS = BlockCounts[i_t, i_h]
T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[bos + i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
# Lei: may have some padding issues
# we should learn from mha varlen templates to handle this
T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return native_sparse_attention_varlen
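# Variable-length convention used above: Q/K/V are packed along the sequence dimension
# ([c_seq_len, heads, dim]); Offsets holds cumulative sequence lengths (e.g. [0, 2, 6] for two
# sequences of lengths 2 and 4), and TokenIndices maps each packed row to its (sequence, position) pair.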
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
batch = len(offsets) - 1
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
kernel = native_sparse_attention_varlen(
batch=batch,
heads=HQ,
c_seq_len=C_SEQ_LEN,
dim=K,
is_causal=True,
block_size=block_size,
groups=G,
selected_blocks=S,
)
o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
kernel(
q.view(C_SEQ_LEN, HQ, K),
k.view(C_SEQ_LEN, H, K),
v.view(C_SEQ_LEN, H, V),
o_slc.view(C_SEQ_LEN, HQ, V),
block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
block_counts.to(torch.int32).view(C_SEQ_LEN, H),
offsets.to(torch.int32),
token_indices.to(torch.int32),
)
return o_slc
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
return o_slc.to(q.dtype)
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, H, T]` (`head_first=True`) or `[B, T, H]` (`head_first=False`) is provided,
each token can select a different number of blocks; an integer means every token selects that many blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
assert False, "Window size is not supported yet"
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = (
torch.cat(
[
torch.tensor([0], dtype=torch.long),
torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]],
torch.tensor([C_SEQ_LEN], dtype=torch.long),
],
0,
)
.cuda()
.sort()[0]
)
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(C_SEQ_LEN, device="cuda")
perm_k = torch.randperm(C_SEQ_LEN, device="cuda")
perm_v = torch.randperm(C_SEQ_LEN, device="cuda")
q = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, HQ, D)
.clone()
.requires_grad_(True)
)
k = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, H, D)
.clone()
.requires_grad_(True)
)
v = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_v]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, H, D)
.clone()
.requires_grad_(True)
)
g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda")
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda")
for i in range(C_SEQ_LEN):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S]
block_indices[0, i, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda")
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets,
)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets,
)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
# prepare_lens / prepare_chunk_indices are needed by the block-mask and backward helpers below
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_chunk_indices, prepare_lens, prepare_token_indices
else:
from fla.ops.utils import prepare_chunk_indices, prepare_lens, prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
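# The stored value is the log-sum-exp per query head, lse = m + log(sum_j exp(s_j - m)); the backward
# kernels rebuild the attention probabilities as p = exp(s - lse) instead of replaying the online softmax.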
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
@staticmethod
@contiguous
@autocast_custom_bwd
def backward(ctx, do_slc, do_swa):
q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors
dq, dk, dv = parallel_nsa_bwd(
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
do_slc=do_slc,
do_swa=do_swa,
block_indices=ctx.block_indices,
block_counts=ctx.block_counts,
block_size=ctx.block_size,
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices,
)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension cannot exceed 256 (128 on pre-Hopper devices)"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BS", "BK", "BV"],
)
@triton.jit(do_not_specialize=["T"])
def parallel_nsa_bwd_kernel_dkv(
q,
k,
v,
lse_slc,
lse_swa,
delta_slc,
delta_swa,
do_slc,
do_swa,
dk,
dv,
block_mask,
offsets,
chunk_indices,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
M: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
):
i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
# [BS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.zeros([BS, BK], dtype=tl.float32)
# [BS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_dv = tl.zeros([BS, BV], dtype=tl.float32)
for i in range(i_s * BS, T):
b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s)
if b_m_slc:
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1))
# [G]
b_lse_slc = tl.load(p_lse_slc)
b_delta_slc = tl.load(p_delta_slc)
# [BS, G]
b_s_slc = tl.dot(b_k, tl.trans(b_q))
b_p_slc = tl.exp(b_s_slc - b_lse_slc[None, :])
b_p_slc = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p_slc, 0)
# [BS, G] @ [G, BV] -> [BS, BV]
b_dv += tl.dot(b_p_slc.to(b_do_slc.dtype), b_do_slc)
# [BS, BV] @ [BV, G] -> [BS, G]
b_dp_slc = tl.dot(b_v, tl.trans(b_do_slc))
# [BS, G]
b_ds_slc = b_p_slc * (b_dp_slc - b_delta_slc[None, :])
# [BS, G] @ [G, BK] -> [BS, BK]
b_dk += tl.dot(b_ds_slc.to(b_q.dtype), b_q)
if WS > 0:
o_s = i_s * BS + tl.arange(0, BS)
if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1):
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1))
# [G]
b_lse_swa = tl.load(p_lse_swa)
b_delta_swa = tl.load(p_delta_swa)
# [BS, G]
b_s_swa = tl.dot(b_k, tl.trans(b_q))
b_p_swa = tl.exp(b_s_swa - b_lse_swa[None, :])
b_p_swa = tl.where(((i >= o_s) & ((i - WS) < o_s))[:, None], b_p_swa, 0)
# [BS, G] @ [G, BV] -> [BS, BV]
b_dv += tl.dot(b_p_swa.to(b_do_swa.dtype), b_do_swa)
# [BS, BV] @ [BV, G] -> [BS, G]
b_dp_swa = tl.dot(b_v, tl.trans(b_do_swa))
# [BS, G]
b_ds_swa = b_p_swa * (b_dp_swa - b_delta_swa[None, :])
# [BS, G] @ [G, BK] -> [BS, BK]
b_dk += tl.dot(b_ds_swa.to(b_q.dtype), b_q)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
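# Gradient identities used by the dkv/dq kernels above and below (a sketch of the standard
# FlashAttention-style backward, with q already pre-scaled inside the kernels):
#   p  = exp(q @ k^T - lse)
#   dv = p^T @ do
#   dp = do @ v^T
#   ds = p * (dp - delta),  delta = rowsum(o * do)
#   dk = ds^T @ q,  dq = (ds @ k) * scale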
@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)})
@triton.jit
def parallel_nsa_kernel_mask(
block_indices,
block_counts,
block_mask,
T: tl.constexpr,
H: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
NS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h, i_s = i_hs // S, i_hs % S
b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s)
if USE_BLOCK_COUNTS:
b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h)
else:
b_m = b_i * BS <= i_t
if b_i < NS and b_i >= 0:
tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty))
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BS", "BK", "BV"],
)
@triton.jit(do_not_specialize=["T"])
def parallel_nsa_bwd_kernel_dq(
q,
k,
v,
lse_slc,
delta_slc,
do_slc,
lse_swa,
delta_swa,
do_swa,
dq,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
q += (bos + i_t) * HQ * K
do_slc += (bos + i_t) * HQ * V
lse_slc += (bos + i_t) * HQ
delta_slc += (bos + i_t) * HQ
if WS > 0:
do_swa += (bos + i_t) * HQ * V
lse_swa += (bos + i_t) * HQ
delta_swa += (bos + i_t) * HQ
dq += (i_v * B * T + bos + i_t) * HQ * K
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_slc = tl.make_block_ptr(do_slc, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + i_h * G + tl.arange(0, G)
p_delta_slc = delta_slc + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1))
# [G]
b_lse_slc = tl.load(p_lse_slc)
b_delta_slc = tl.load(p_delta_slc)
# [G, BK]
b_dq_slc = tl.zeros([G, BK], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BV, BS]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_p_slc = tl.exp(b_s_slc - b_lse_slc[:, None])
b_p_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_slc, 0)
# [G, BV] @ [BV, BS] -> [G, BS]
b_dp_slc = tl.dot(b_do_slc, b_v_slc)
b_ds_slc = b_p_slc * (b_dp_slc.to(tl.float32) - b_delta_slc[:, None])
# [G, BS] @ [BS, BK] -> [G, BK]
b_dq_slc += tl.dot(b_ds_slc.to(b_k_slc.dtype), tl.trans(b_k_slc))
b_dq_slc *= scale
if WS > 0:
p_do_swa = tl.make_block_ptr(do_swa, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + i_h * G + tl.arange(0, G)
p_delta_swa = delta_swa + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1))
# [G]
b_lse_swa = tl.load(p_lse_swa)
b_delta_swa = tl.load(p_delta_swa)
# [G, BK]
b_dq_swa = tl.zeros([G, BK], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BV, BS]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_p_swa = tl.exp(b_s_swa - b_lse_swa[:, None])
b_p_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_swa, 0)
# [G, BV] @ [BV, BS] -> [G, BS]
b_dp_swa = tl.dot(b_do_swa, b_v_swa)
b_ds_swa = b_p_swa * (b_dp_swa.to(tl.float32) - b_delta_swa[:, None])
# [G, BS] @ [BS, BK] -> [G, BK]
b_dq_swa += tl.dot(b_ds_swa.to(b_k_swa.dtype), tl.trans(b_k_swa))
b_dq_swa *= scale
if WS == 0:
tl.store(p_dq, b_dq_slc.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
else:
tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BS, BV]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf"))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
b_r_swa = tl.exp(b_mp_swa - b_m_swa)
# [G, BS]
b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None])
# [G]
b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1)
# [G, BV]
b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa)
b_mp_swa = b_m_swa
b_o_swa = b_o_swa / b_acc_swa[:, None]
b_m_swa += tl.log(b_acc_swa)
tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty))
@triton.jit
def parallel_nsa_bwd_kernel_preprocess(o, do, delta, B: tl.constexpr, V: tl.constexpr):
i_n = tl.program_id(0)
o_d = tl.arange(0, B)
m_d = o_d < V
b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
b_delta = tl.sum(b_o * b_do)
tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
def parallel_nsa_block_mask(
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
offsets: torch.LongTensor,
block_size: int,
):
B, T, H, S = block_indices.shape
BS = block_size
if offsets is not None:
NS = triton.cdiv(prepare_lens(offsets).max().item(), BS)
else:
NS = triton.cdiv(T, BS)
block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device)
parallel_nsa_kernel_mask[(T, B, H * S)](
block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS
)
return block_mask
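# block_mask[b, t, h, i] is True when key/value block i was selected for query t of head h and starts
# at or before t; the dkv kernel reads it to skip (query, key-block) pairs where the block was not selected.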
def parallel_nsa_bwd_preprocess(o: torch.Tensor, do: torch.Tensor):
V = o.shape[-1]
delta = torch.empty_like(o[..., 0], dtype=torch.float32)
parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)](
o=o,
do=do,
delta=delta,
B=triton.next_power_of_2(V),
V=V,
)
return delta
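# Reference semantics: delta[b, t, h] == (o[b, t, h] * do[b, t, h]).sum(-1), the rowwise term
# subtracted inside ds = p * (dp - delta) in the backward kernels.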
def parallel_nsa_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o_slc: torch.Tensor,
lse_slc: torch.Tensor,
do_slc: torch.Tensor,
o_swa: torch.Tensor,
lse_swa: torch.Tensor,
do_swa: torch.Tensor,
block_indices: torch.Tensor,
block_counts: Union[torch.LongTensor, int],
block_size: int = 64,
window_size: int = 0,
scale: float = None,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
BK = triton.next_power_of_2(K)
BV = min(128, triton.next_power_of_2(v.shape[-1]))
NV = triton.cdiv(V, BV)
delta_slc = parallel_nsa_bwd_preprocess(o_slc, do_slc)
delta_swa = parallel_nsa_bwd_preprocess(o_swa, do_swa) if window_size > 0 else None
dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
grid = (T, NV, B * H)
parallel_nsa_bwd_kernel_dq[grid](
q=q,
k=k,
v=v,
lse_slc=lse_slc,
delta_slc=delta_slc,
do_slc=do_slc,
lse_swa=lse_swa,
delta_swa=delta_swa,
do_swa=do_swa,
dq=dq,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
scale=scale,
T=T,
B=B,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
dq = dq.sum(0)
if offsets is not None:
chunk_indices = prepare_chunk_indices(offsets, BS)
NS = len(chunk_indices)
else:
chunk_indices = None
NS = triton.cdiv(T, BS)
# [B, T, H, M]
block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size)
dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)
grid = (NV, NS, B * H)
parallel_nsa_bwd_kernel_dkv[grid](
q=q,
k=k,
v=v,
lse_slc=lse_slc,
lse_swa=lse_swa,
delta_slc=delta_slc,
delta_swa=delta_swa,
do_slc=do_slc,
do_swa=do_swa,
dk=dk,
dv=dv,
block_mask=block_mask,
offsets=offsets,
chunk_indices=chunk_indices,
scale=scale,
T=T,
B=B,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
M=block_mask.shape[-1],
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
dk = dk.sum(0)
return dq, dk, dv
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
@staticmethod
@contiguous
@autocast_custom_bwd
def backward(ctx, do_slc, do_swa):
q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors
dq, dk, dv = parallel_nsa_bwd(
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
do_slc=do_slc,
do_swa=do_swa,
block_indices=ctx.block_indices,
block_counts=ctx.block_counts,
block_size=ctx.block_size,
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices,
)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, H, T]` (`head_first=True`) or `[B, T, H]` (`head_first=False`) is provided,
each token can select a different number of blocks; an integer means every token selects that many blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
ref.backward(do)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
block_counts=block_counts,
)
print("tri", tri)
print("ref", ref)
tri.backward(do)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
# assert_close(" o", ref, tri, 0.004)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
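# NOTE: the per-token block_counts path is commented out in this variant, so every token scans all S selected blocks.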
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
each token can select a different number of blocks; if an integer is provided, every token selects the same number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
block_counts=block_counts,
)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BS, BV]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf"))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
b_r_swa = tl.exp(b_mp_swa - b_m_swa)
# [G, BS]
b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None])
# [G]
b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1)
# [G, BV]
b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa)
b_mp_swa = b_m_swa
b_o_swa = b_o_swa / b_acc_swa[:, None]
b_m_swa += tl.log(b_acc_swa)
tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty))
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
each token can select a different number of blocks; if an integer is provided, every token selects the same number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = (
torch.cat(
[torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 16)[: N - 1]], torch.tensor([T], dtype=torch.long)],
0,
)
.cuda()
.sort()[0]
)
# offsets.shape is [N+1]
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(T, device="cuda")
perm_k = torch.randperm(T, device="cuda")
perm_v = torch.randperm(T, device="cuda")
q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda")
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda")
for i in range(T):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
block_indices[0, i, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda")
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets,
)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets,
)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# ruff: noqa
from typing import Optional
import torch
from typing import Union
from einops import rearrange, repeat
def naive_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
each token can select a different number of blocks; if an integer is provided, every token selects the same number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices)
)
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = (
torch.einsum("h d, n h d -> n h", q_i, k_i_slc)
.masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf"))
.softmax(0)
)
if not varlen:
o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b))
attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def naive_nsa_simple(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1] ** -0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(v)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[i_q]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[i_q]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum("h d, n h d -> n h", q_i, k_i)
attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf"))
attn = torch.softmax(attn, dim=0)
o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i)
return o.to(dtype)
def naive_nsa_simple_inference(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1] ** -0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
# [HQ, D]
q_i = q_b[0] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[0]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[0]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum("h d, n h d -> n h", q_i, k_i)
attn = attn.masked_fill((c >= s_i), float("-inf"))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i)
return o.to(dtype)
git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e
# ruff: noqa
import tilelang.testing
from example_tilelang_nsa_fwd import main as main_fwd
from example_tilelang_nsa_decode import main as main_fwd_decode
def test_example_tilelang_nsa_fwd():
main_fwd()
def test_example_tilelang_nsa_fwd_decode():
main_fwd_decode()
if __name__ == "__main__":
tilelang.testing.main()
## Directory Structure
```
deepseek_v32/
├── README.md # This file
├── figures/ # Figures and diagrams
├── inference/ # Inference implementation folder
├── fp8_lighting_indexer.py # FP8 lightning indexer
├── sparse_mla_bwd.py # Sparse MLA backward implementation
├── sparse_mla_fwd.py # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass
├── topk_selector.py # Top-k selector implementation
```
## File Descriptions
### Architecture Overview
![DeepSeek V3.2 Architecture](./figures/v32_arch.png)
The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:
1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-head Latent Attention) forward and backward passes
### Lightning Indexer
Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.
The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:
```python
T.gemm(
    index_k_shared,
    index_q_shared,
    s,
    transpose_B=True,
    clear_accum=True,
    policy=T.GemmWarpPolicy.FullCol,
)
```
After the matmul, we apply ReLU and aggregate across heads with learned weights:
```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
    s_reshaped[bn_i, bq_i, h_i] = (
        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
    ) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
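In plain PyTorch, the quantity these two steps produce is (leaving aside the FP8 quantization and scale handling) a per-head ReLU'd dot product aggregated with learned weights. The sketch below is only an illustration with assumed tensor names and shapes, not the kernel's API:

```python
import torch

def indexer_logits_reference(q_idx: torch.Tensor, k_idx: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Reference for the indexer math (illustrative shapes, not the kernel signature).

    q_idx:   [seq_len, heads, dim]   index queries
    k_idx:   [seq_len_kv, dim]       index keys
    weights: [seq_len, heads]        per-head aggregation weights
    returns: [seq_len, seq_len_kv]   logits consumed by the top-k selector
    """
    # Per-head similarity scores: [seq_len, heads, seq_len_kv]
    scores = torch.einsum("thd,jd->thj", q_idx.float(), k_idx.float())
    # ReLU, then aggregate across heads with the learned weights
    return torch.einsum("thj,th->tj", scores.relu(), weights.float())
```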
The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:
```python
for bq_i in T.serial(block_Q):
    cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
    cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```
The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
### Top-k Selector
The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.
The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:
```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
```
The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin:
```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
    s_threshold_bin_id[0] = tx
```
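The order-preserving float-to-integer mapping mentioned above is typically the standard radix-sort sign trick. A plain-PyTorch sketch of the idea (illustrative only, not the kernel's actual `convert_to_uint16`) is:

```python
import torch

def float16_to_ordered_uint16(x: torch.Tensor) -> torch.Tensor:
    """Map fp16 values to 16-bit codes whose unsigned integer order matches the float order.

    Radix-sort trick: set the sign bit for non-negative floats, flip all bits
    for negative floats. NaNs are ignored in this sketch; codes are returned
    as int32 for convenience but always fit in [0, 2**16).
    """
    bits = x.to(torch.float16).view(torch.int16).to(torch.int32) & 0xFFFF
    negative = (bits & 0x8000) != 0
    return torch.where(negative, bits ^ 0xFFFF, bits | 0x8000)
```

Because larger floats map to strictly larger codes, the histogram and cumulative-sum pass above can bracket the top-k threshold one byte at a time without sorting.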
Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:
```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
    index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
    s_input_idx[0, pos] = input_idx
```
Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
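As a correctness reference for this selection logic (assuming the kernel returns, per query, the indices of the `topk` largest logits in arbitrary order), the output can be compared against `torch.topk` on sorted index sets; a minimal sketch:

```python
import torch

def reference_topk_indices(logits: torch.Tensor, topk: int) -> torch.Tensor:
    """Exact per-row top-k indices, for checking a radix-select kernel.

    logits: [num_queries, seq_len_kv]; returns [num_queries, topk].
    The kernel may emit the same set in a different order, so compare sorted
    index sets; ties at the threshold may legitimately differ.
    """
    return torch.topk(logits, k=topk, dim=-1).indices.sort(dim=-1).values
```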
### Sparse MLA Forward
The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.
Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:
```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
    T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
    # ... compute attention over this block
```
Sparse MLA only loads KV positions selected by the top-k selector:
```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    # ... compute attention over selected tokens
```
This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:
```python
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```
Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
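For intuition, the gather-then-attend step for a single query position can be written as a few lines of non-fused PyTorch. The names, shapes, and the simplification that values live in the same latent `kv` tensor are assumptions for illustration, not the kernel's interface:

```python
import torch

def sparse_mla_reference_one_token(q, kv, indices, max_kv_i, scale):
    """Non-fused reference for one query token (illustrative only).

    q:        [heads, dim]      query heads sharing one KV group
    kv:       [seqlen_kv, dim]  latent KV cache for that group
    indices:  [topk]            positions chosen by the top-k selector
    max_kv_i: int               last causally visible KV position
    """
    kv_sel = kv[indices]                                  # gather only the selected tokens
    scores = (q.float() @ kv_sel.float().t()) * scale     # [heads, topk]
    # Mirror the kernel's validity check: mask out indices beyond the causal bound
    scores = scores.masked_fill((indices > max_kv_i)[None, :], float("-inf"))
    probs = scores.softmax(dim=-1)
    return probs @ kv_sel.float()                         # [heads, dim]
```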
### Sparse MLA Forward (Pipelined)
The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.
The key difference is splitting the warp groups into specialized roles:
```python
if tx < 128:
    # Consumer 0: computes left half of output (D//2 dimensions)
    # Handles QK matmul, softmax, and PV for left half
elif tx >= 128 and tx < 256:
    # Consumer 1: computes right half of output (D//2 dimensions)
    # Only does PV matmul for right half
elif tx >= 256:
    # Producer: loads KV data from global memory
    # Uses async copy with barriers to feed consumers
```
The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed:
```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
    # Buffer 0
    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 0
    T.cp_async_barrier_noinc(bar_k_0_ready[0])
    # Buffer 1
    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 1
    T.cp_async_barrier_noinc(bar_k_1_ready[0])
```
Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
### Sparse MLA Backward
The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.
The backward pass consists of three main stages:
**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):
```python
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
    T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
    T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
    for i, j in T.Parallel(block_ND, block_ND):
        acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
```
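Outside the kernel's tiling, this preprocessing is just a row-wise reduction over the feature dimension (a one-line PyTorch equivalent, assuming `O` and `dO` of shape `[B, T, H, D]`):

```python
# Delta[b, t, h] = sum_d O[b, t, h, d] * dO[b, t, h, d]
delta = (O.float() * dO.float()).sum(dim=-1)
```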
**2. Main Backward Computation**: Computes gradients through sparse attention:
```python
# Sparse MLA backward: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    # Load KV data for selected indices
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i]
    # Recompute attention scores for backward
    T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
    # Apply softmax gradient: dP = P * (dP_raw - Delta)
    for h_i, bi_i in T.Parallel(padded_H, BI):
        acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
```
The key gradient computations, sanity-checked in the sketch after this list, are:
- **dQ = dP @ K** (query gradients)
- **dK = dP^T @ Q** (key gradients)
- **dV = P^T @ dO** (value gradients)
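These identities, together with the `dP = P * (dP_raw - Delta)` softmax backward used above, can be checked against autograd on a tiny dense example. The following standalone sketch is unrelated to the kernel's tiling and uses float64 only to make the comparison tight:

```python
import torch

torch.manual_seed(0)
H, N, D = 4, 8, 16  # tiny dense problem: H query rows, N keys, head dim D
q = torch.randn(H, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
do = torch.randn(H, D, dtype=torch.float64)
sm_scale = D ** -0.5

s = (q @ k.t()) * sm_scale
p = s.softmax(dim=-1)
o = p @ v
o.backward(do)

# Manual backward using the identities above
dp_raw = do @ v.t()                          # dL/dP
delta = (o * do).sum(dim=-1, keepdim=True)   # row-wise dot of O and dO (the "Delta" term)
dp = p * (dp_raw - delta) * sm_scale         # softmax backward, with the score scale folded in
dq = dp @ k                                  # dQ = dP @ K
dk = dp.t() @ q                              # dK = dP^T @ Q
dv = p.t() @ do                              # dV = P^T @ dO

assert torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad)
```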
**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation:
```python
# Atomically update dKV at selected indices
for bi_i, d_i in T.Parallel(BI // split_store, D // 4):
    T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4],
                   acc_dkv_shared[bi_i, d_i * 4])
```
**Performance**: The sparse MLA backward achieves excellent performance:
- **H800 SXM**: ~100 TFlops
- **H200 SXM**: ~115 TFlops
The implementation efficiently handles the irregular memory access patterns inherent in sparse attention while maintaining high compute utilization through careful memory management and atomic update strategies. Note that this is a relatively naive implementation that requires further optimization.