Commit 8e1845d2 authored by Lei Wang, committed by GitHub

[Example] Implement tilelang native sparse attention varlen example (#170)

* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels

- Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files
- Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
- Move map_torch_type utility function to tilelang.utils.tensor
- Remove unnecessary imports and improve code organization

* Refactor Native Sparse Attention Example with Enhanced Triton Kernel

- Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
- Add support for block counts and offsets in the Triton kernel
- Modify kernel grid and computation logic for improved performance
- Update example script to use naive_nsa_simple reference implementation
- Improve type hints and kernel configuration

* Add Native Sparse Attention Examples with TileLang and Triton Implementations

- Introduce new example scripts for native sparse attention:
  * example_tilelang_nsa_fwd.py: Forward pass implementation using TileLang
  * example_tilelang_nsa_decode.py: Decoding-specific sparse attention implementation
  * example_triton_nsa_fwd.py: Triton-based sparse attention forward pass
- Update reference.py with naive implementations for sparse attention
- Support different sparse attention scenarios including forward pass and inference
- Add comprehensive testing and validation against reference implementations

* lint fix

* Add Variable-Length Native Sparse Attention Examples for TileLang and Triton

- Introduce new example scripts for variable-length native sparse attention:
  * example_tilelang_nsa_fwd_varlen.py: TileLang implementation with variable sequence lengths
  * example_triton_nsa_fwd_varlen.py: Triton implementation with variable sequence lengths
- Update reference.py to support variable-length sparse attention scenarios
- Enhance existing sparse attention implementations to handle variable-length inputs
- Add comprehensive testing and validation for variable-length sparse attention

* Refactor Native Sparse Attention Examples: Code Style and Formatting Improvements

- Standardize function and parameter formatting across NSA example files
- Improve code readability by adjusting indentation and line breaks
- Enhance type hints and parameter alignment
- Remove unnecessary whitespaces and optimize imports
- Maintain consistent code style across TileLang and Triton implementations
parent 166a9585
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from reference import naive_nsa, naive_nsa_simple
from reference import naive_nsa
import tilelang
from tilelang import language as T
import tilelang.testing
@@ -142,7 +140,8 @@ if __name__ == "__main__":
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True)
DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
@@ -156,10 +155,12 @@ if __name__ == "__main__":
out = kernel(Q, K, V, block_indices.to(torch.int32))
ref = naive_nsa_simple(
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size)
......
# ruff: noqa
import torch
from typing import Optional, Union
import tilelang
from tilelang import language as T
import tilelang.testing
from fla.ops.common.utils import prepare_token_indices
from reference import naive_nsa
from einops import rearrange
def native_sparse_attention_varlen(batch,
heads,
c_seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=1,
selected_blocks=16):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [c_seq_len, heads, dim]
kv_shape = [c_seq_len, head_kv, dim]
o_slc_shape = [c_seq_len, heads, dim]
o_swa_shape = [c_seq_len, heads, dim]
lse_slc_shape = [c_seq_len, heads]
lse_swa_shape = [c_seq_len, heads]
block_indices_shape = [c_seq_len, head_kv, selected_blocks]
block_counts_shape = [c_seq_len, head_kv]
offsets_shape = [batch + 1]
token_indices_shape = [c_seq_len, 2]
block_indices_dtype = "int32"
block_counts_dtype = "int32"
offsets_dtype = "int32"
token_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The key dimension cannot be larger than 128 (block_T is capped at 128)"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 0
threads = 32
@T.prim_func
def native_sparse_attention_varlen(
Q: T.Buffer(q_shape, dtype),
K: T.Buffer(kv_shape, dtype),
V: T.Buffer(kv_shape, dtype),
O_slc: T.Buffer(o_slc_shape, dtype),
BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
BlockCounts: T.Buffer(block_counts_shape, block_counts_dtype),
Offsets: T.Buffer(offsets_shape, offsets_dtype),
TokenIndices: T.Buffer(token_indices_shape, token_indices_dtype),
):
with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_c, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
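# Variable-length bookkeeping: TokenIndices maps the packed position i_c to
# (sequence id i_n, in-sequence position i_t); Offsets gives the packed
# start/end of that sequence; BlockCounts holds the number of selected KV
# blocks for this token and KV head.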
i_n, i_t = TokenIndices[i_c, 0], TokenIndices[i_c, 1]
bos = Offsets[i_n]
eos = Offsets[i_n + 1]
current_seq_len = eos - bos
NS = BlockCounts[bos + i_t, i_h]
T.copy(Q[bos + i_t, i_h * G:(i_h + 1) * G, :BK], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
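# Online softmax over the selected KV blocks: each iteration loads one selected
# block, updates the running row max and normalizer in log2 space (the scale
# already folds in log2(e)), and rescales the partial output accordingly.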
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[bos + i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
# Lei: may have some padding issues
# we should learn from mha varlen templates to handle this
T.copy(K[bos + i_s:bos + i_s + BS, i_h, :BK], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[bos + i_s:bos + i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, O_slc[bos + i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV])
return native_sparse_attention_varlen
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
batch = len(offsets) - 1
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
program = native_sparse_attention_varlen(
batch=batch,
heads=HQ,
c_seq_len=C_SEQ_LEN,
dim=K,
is_causal=True,
block_size=block_size,
groups=G,
selected_blocks=S,
)
kernel = tilelang.compile(program)
o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
kernel(
q.view(C_SEQ_LEN, HQ, K), k.view(C_SEQ_LEN, H, K), v.view(C_SEQ_LEN, H, V),
o_slc.view(C_SEQ_LEN, HQ, V),
block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32),
token_indices.to(torch.int32))
return o_slc
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
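# A rough equivalent of the fla helper (sketch only; the actual implementation may differ):
#   lens = offsets[1:] - offsets[:-1]
#   token_indices = torch.cat([
#       torch.stack([torch.full((n,), i), torch.arange(n)], dim=-1)
#       for i, n in enumerate(lens.tolist())
#   ])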
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
return o_slc.to(q.dtype)
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`,
each token can select a different number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size,
scale, cu_seqlens)
if window_size > 0:
assert False, "Window size is not supported yet"
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
if __name__ == "__main__":
N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = torch.cat([
torch.tensor([0], dtype=torch.long),
torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[:N - 1]],
torch.tensor([C_SEQ_LEN], dtype=torch.long)
], 0).cuda().sort()[0]
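# offsets has shape [N+1]; e.g. with N=2 and C_SEQ_LEN=64 it might come out as
# [0, 23, 64], splitting the 64 packed tokens into sequences of 23 and 41.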
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(C_SEQ_LEN, device='cuda')
perm_k = torch.randperm(C_SEQ_LEN, device='cuda')
perm_v = torch.randperm(C_SEQ_LEN, device='cuda')
q = torch.linspace(
0, 1, steps=C_SEQ_LEN, dtype=dtype,
device='cuda')[perm_q].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, HQ,
D).clone().requires_grad_(True)
k = torch.linspace(
0, 1, steps=C_SEQ_LEN, dtype=dtype,
device='cuda')[perm_k].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H,
D).clone().requires_grad_(True)
v = torch.linspace(
0, 1, steps=C_SEQ_LEN, dtype=dtype,
device='cuda')[perm_v].view(1, C_SEQ_LEN, 1, 1).expand(1, C_SEQ_LEN, H,
D).clone().requires_grad_(True)
g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device='cuda')
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device='cuda')
for i in range(C_SEQ_LEN):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S]
block_indices[0, i, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device='cuda')
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from typing import Optional, Union
@@ -7,12 +5,11 @@ from typing import Optional, Union
import torch
import triton
import triton.language as tl
from einops import rearrange
from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens, prepare_token_indices)
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa, naive_nsa_simple
from fla.ops.common.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
@@ -70,8 +67,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
if i_t == 6:
print("b_s_slc", b_s_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
@@ -122,13 +117,12 @@ def parallel_nsa_fwd(
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
import math
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
scale = 1.0 / math.sqrt(K)
G = HQ // H
BS = block_size
WS = window_size
@@ -176,12 +170,128 @@ def parallel_nsa_fwd(
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`,
each token can select a different number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
@@ -194,23 +304,27 @@ if __name__ == "__main__":
block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
ref = naive_nsa_simple(
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
tri, _, _, _ = parallel_nsa_fwd(
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
window_size=0,
block_counts=block_counts)
block_counts=block_counts,
)
print("tri", tri)
print("ref", ref)
......
# ruff: noqa
import torch
from typing import Optional, Union
import torch
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
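# With variable-length inputs, program id 0 indexes packed tokens: recover the
# owning sequence (i_n) and the in-sequence position (i_t) from token_indices,
# then the packed start/end (bos/eos) from offsets. Otherwise every sequence has
# length T and bos is simply i_b * T.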
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
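# Iterate over the NS selected KV blocks for this token; padded slots in
# block_indices point past the current token (or are negative) and are skipped
# by the bounds check below.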
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
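# Finalize: divide by the accumulated softmax denominator and store the
# log-sum-exp (running max plus log of the normalizer) for each query head.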
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BS, BV]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf'))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
b_r_swa = tl.exp(b_mp_swa - b_m_swa)
# [G, BS]
b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None])
# [G]
b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1)
# [G, BV]
b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa)
b_mp_swa = b_m_swa
b_o_swa = b_o_swa / b_acc_swa[:, None]
b_m_swa += tl.log(b_acc_swa)
tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty))
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
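# Tile sizes along K/V depend on compute capability: devices with capability >= 9
# (Hopper) use tiles up to 256, while older devices are capped at 128, presumably
# to stay within per-SM resource limits.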
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension cannot be larger than 256"
grid = (T, NV, B * H)
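# One program instance per (query token, value tile, batch * KV head).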
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`,
each token can select a different number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
if __name__ == "__main__":
N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = torch.cat([
torch.tensor([0], dtype=torch.long),
torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]],
torch.tensor([T], dtype=torch.long)
], 0).cuda().sort()[0]
# offsets.shape is [N+1]
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(T, device='cuda')
perm_k = torch.randperm(T, device='cuda')
perm_v = torch.randperm(T, device='cuda')
q = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
v = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda')
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda')
for i in range(T):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
block_indices[0, i, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
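# Slots left at the sentinel value T are out of range for every query position
# (T * block_size always exceeds the query index), so the kernel's i_s <= i_t
# check ignores them.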
block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda')
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
from typing import Optional
import torch
import torch.nn.functional as F
from typing import Union
from einops import rearrange, repeat
def naive_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
head_first: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None) -> torch.Tensor:
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, H, T]` if `head_first=True` else `[B, T, H]`,
each token can select a different number of blocks.
If not provided, it will default to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
Returns:
o (torch.Tensor):
@@ -49,24 +59,29 @@ def naive_nsa(q: torch.Tensor,
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
block_counts = rearrange(block_counts, 'b h t -> b t h')
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(v)
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
@@ -77,11 +92,21 @@ def naive_nsa(q: torch.Tensor,
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, i_b, s_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, block_indices, block_counts))
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
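# Expand each selected block index into the BS token positions it covers, so the
# gather along dim 0 below pulls whole key/value blocks per query token.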
# [T, S*BS, HQ]
@@ -89,25 +114,125 @@ def naive_nsa(q: torch.Tensor,
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [1, HQ]
s_i = s_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i, v_i = map(
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill((i_i > i_q) | (c >= s_i),
float('-inf')).softmax(0)
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
if not varlen:
o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def naive_nsa_simple_inference(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1]**-0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
# [HQ, D]
q_i = q_b[0] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[0]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[0]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
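# Note: this elementwise gather is intentionally simple and slow; it is only a
# readable reference for the block gathering done by the fused kernels.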
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((c >= s_i), float('-inf'))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
return o.to(dtype)
......
@@ -28,7 +28,6 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=32) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
# Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -45,9 +44,7 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
# T.copy(Q_shared, Q_local)
# for i, j in T.Parallel(block_M, dim):
# Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
@@ -293,19 +290,18 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
def test_mha_bwd():
......
@@ -91,7 +91,7 @@ cdef class CythonKernelWrapper:
for param, (buffer_idx, shape_list) in self.static_shape_map.items():
for shape_idx, shape in shape_list:
if tensor_list[buffer_idx].shape[shape_idx] != shape:
raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape}, got {tensor_list[buffer_idx].shape}")
raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
# Add dynamic dimension values to kernel arguments
for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
......