Commit 69f35439 authored by Lei Wang, committed by GitHub

[Example] Implement NSA decode TileLang examples (#168)

* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels

- Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files
- Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
- Move map_torch_type utility function to tilelang.utils.tensor
- Remove unnecessary imports and improve code organization

* Refactor Native Sparse Attention Example with Enhanced Triton Kernel

- Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
- Add support for block counts and offsets in the Triton kernel
- Modify kernel grid and computation logic for improved performance
- Update example script to use naive_nsa_simple reference implementation
- Improve type hints and kernel configuration

* Add Native Sparse Attention Examples with TileLang and Triton Implementations

- Introduce new example scripts for native sparse attention:
  * example_tilelang_nsa_fwd.py: Forward pass implementation using TileLang
  * example_tilelang_nsa_decode.py: Decoding-specific sparse attention implementation
  * example_triton_nsa_fwd.py: Triton-based sparse attention forward pass
- Update reference.py with naive implementations for sparse attention
- Support different sparse attention scenarios including forward pass and inference
- Add comprehensive testing and validation against reference implementations

* lint fix
parent b6c48453
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from reference import naive_nsa_simple_inference
import tilelang
from tilelang import language as T
import tilelang.testing
tilelang.testing.set_random_seed(42)

def native_sparse_attention(
        batch,
        heads,
        seq_len,  # Length of K/V sequences (context window size)
        dim,  # Embedding dimension per head
        scale=None,
        block_size=64,  # Tile size for attention computation
        groups=1,  # Grouped query attention (GQA) groups
        selected_blocks=16  # Number of blocks to select per attention head
):
    if scale is None:
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)

    head_kv = heads // groups
    # Modified shapes for inference (q has seq_len=1)
    q_shape = [batch, 1, heads, dim]  # Changed seq_len to 1
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, 1, head_kv, selected_blocks]  # Changed seq_len to 1
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    assert NK == 1, "The key dimension can not be larger than 256"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    @T.prim_func
    def native_sparse_attention(
            Q: T.Buffer(q_shape, dtype),  # [batch, 1, heads, dim]
            K: T.Buffer(kv_shape, dtype),  # [batch, seq_len, head_kv, dim]
            V: T.Buffer(kv_shape, dtype),  # Same shape as K
            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),  # Selected block indices
            Output: T.Buffer(q_shape, dtype),  # Output attention tensor
    ):
        with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
            # Shared memory allocations for tile storage
            Q_shared = T.alloc_shared([G, BK], dtype)  # Current query block
            K_shared = T.alloc_shared([BS, BK], dtype)  # Current key block
            V_shared = T.alloc_shared([BS, BV], dtype)  # Current value block
            O_shared = T.alloc_shared([G, BV], dtype)  # Output accumulator

            # Attention computation buffers
            acc_s = T.alloc_fragment([G, BS], accum_dtype)  # QK^T scores
            acc_s_cast = T.alloc_fragment([G, BS], dtype)  # Casted scores for softmax
            acc_o = T.alloc_fragment([G, BV], accum_dtype)  # Output accumulator
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_v, i_bh = by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            # Copy Q for the single position
            T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared)  # Changed i_t to 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Main attention computation loop over selected blocks
            for i in T.Pipelined(NS, num_stages=num_stages):
                i_s = BlockIndices[i_b, 0, i_h, i] * BS  # Get block offset
                if i_s >= 0:  # Skip invalid/padding blocks
                    # Load current key block to shared memory
                    T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)

                    # Compute QK^T attention scores
                    T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)

                    # Online softmax with numerical stability
                    # 1. Compute max for scaling
                    # 2. Compute exponentials and sum
                    # 3. Maintain running logsum for normalization
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)
                    for i in T.Parallel(G):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(G, BS):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(G):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale previously accumulated output to the new running max
                    for i, j in T.Parallel(G, BV):
                        acc_o[i, j] *= scores_scale[i]

                    # Accumulate attention-weighted values
                    T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Final normalization and output
            for i, j in T.Parallel(G, BV):
                acc_o[i, j] /= logsum[i]  # Normalize by logsum
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G,
                                    i_v * BV:(i_v + 1) * BV])  # Changed i_t to 0

    return native_sparse_attention


if __name__ == "__main__":
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
    groups = HQ // H
    SEQ_LEN_Q = 1
    program = native_sparse_attention(
        batch=B,
        heads=HQ,
        seq_len=SEQ_LEN,
        dim=D,
        block_size=block_size,
        groups=HQ // H,
        selected_blocks=S,
    )
    kernel = tilelang.compile(program, out_idx=-1)

    Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda')
    DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda')

    block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
    for b in range(B):
        for t in range(SEQ_LEN_Q):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda')

    out = kernel(Q, K, V, block_indices.to(torch.int32))

    ref = naive_nsa_simple_inference(
        q=Q,
        k=K,
        v=V,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
    )
    print("out", out)
    print("ref", ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
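
The inner loop above follows the standard online-softmax recurrence: each selected KV block updates a running maximum (scores_max), rescales previously accumulated state by scores_scale, and folds the block's score sum into logsum, so a single division by logsum at the end yields softmax-normalized output. The short standalone PyTorch sketch below (the helper name online_softmax_weighted_sum and the toy shapes are invented for illustration and are not part of the example files) checks that this streaming update matches a softmax computed over all blocks at once:

import torch

def online_softmax_weighted_sum(score_blocks, value_blocks, scale):
    """Streaming softmax(scores) @ values, mirroring the kernel's update rule."""
    m = torch.tensor(float("-inf"))               # running max        (scores_max)
    norm = torch.tensor(0.0)                      # running normalizer (logsum)
    acc = torch.zeros(value_blocks[0].shape[-1])  # unnormalized output (acc_o)
    for s, v in zip(score_blocks, value_blocks):
        m_prev = m
        m = torch.maximum(m_prev, s.max())
        alpha = torch.exp2(m_prev * scale - m * scale)  # scores_scale
        p = torch.exp2(s * scale - m * scale)           # exp2-scaled block scores
        norm = norm * alpha + p.sum()                   # logsum update
        acc = acc * alpha + p @ v                       # rescale, then accumulate
    return acc / norm                                   # final normalization

torch.manual_seed(0)
scale = 1.44269504  # log2(e), so exp2(x * scale) == exp(x)
score_blocks = [torch.randn(32) for _ in range(4)]   # four "selected" KV blocks
value_blocks = [torch.randn(32, 8) for _ in range(4)]
streamed = online_softmax_weighted_sum(score_blocks, value_blocks, scale)
full = torch.softmax(torch.cat(score_blocks), dim=0) @ torch.cat(value_blocks)
print(torch.allclose(streamed, full, atol=1e-5))  # expected: True

Because the scores are pre-multiplied by log2(e) (the 1.44269504 factor in scale), exp2 can be used instead of exp, which the kernels above exploit for cheaper fused multiply-add code.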
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
# ruff: noqa
import torch
from reference import naive_nsa, naive_nsa_simple
import tilelang
from tilelang import language as T
import tilelang.testing
@@ -16,6 +16,7 @@ def native_sparse_attention(batch,
                            dim,
                            is_causal,
                            scale=None,
                            block_size=64,
                            groups=1,
                            selected_blocks=16):
    if scale is None:
@@ -27,116 +28,104 @@ def native_sparse_attention(batch,
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    assert NK == 1, "The key dimension can not be larger than 256"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    @T.prim_func
    def native_sparse_attention(
            Q: T.Buffer(q_shape, dtype),
            K: T.Buffer(kv_shape, dtype),
            V: T.Buffer(kv_shape, dtype),
            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
            Output: T.Buffer(q_shape, dtype),
    ):
        with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([G, BK], dtype)
            K_shared = T.alloc_shared([BS, BK], dtype)
            V_shared = T.alloc_shared([BS, BV], dtype)
            O_shared = T.alloc_shared([G, BV], dtype)

            acc_s = T.alloc_fragment([G, BS], accum_dtype)
            acc_s_cast = T.alloc_fragment([G, BS], dtype)
            acc_o = T.alloc_fragment([G, BV], accum_dtype)
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_t, i_v, i_bh = bx, by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            for i in T.Pipelined(NS, num_stages=num_stages):
                i_s = BlockIndices[i_b, i_t, i_h, i] * BS
                if i_s <= i_t and i_s >= 0:
                    # [BS, BK]
                    T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)

                    if is_causal:
                        for i, j in T.Parallel(G, BS):
                            acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
                                                         -T.infinity(acc_s.dtype))
                    else:
                        T.clear(acc_s)

                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)

                    # Softmax
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)
                    for i in T.Parallel(G):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(G, BS):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(G):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    T.copy(acc_s, acc_s_cast)

                    # Rescale
                    for i, j in T.Parallel(G, BV):
                        acc_o[i, j] *= scores_scale[i]

                    # V * softmax(Q * K)
                    T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            for i, j in T.Parallel(G, BV):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV])

    return native_sparse_attention


if __name__ == "__main__":
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
    program = native_sparse_attention(
        batch=B,
@@ -144,15 +133,16 @@ if __name__ == "__main__":
        seq_len=SEQ_LEN,
        dim=D,
        is_causal=True,
        block_size=block_size,
        groups=HQ // H,
        selected_blocks=S,
    )
    kernel = tilelang.compile(program, out_idx=-1)
    torch.random.manual_seed(0)

    Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda')
    block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
@@ -166,16 +156,14 @@ if __name__ == "__main__":
    out = kernel(Q, K, V, block_indices.to(torch.int32))

    ref = naive_nsa_simple(
        q=Q,
        k=K,
        v=V,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size)

    print("out", out)
    print("ref", ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
@@ -109,3 +109,159 @@ def naive_nsa(q: torch.Tensor,
    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o.to(dtype)


def naive_nsa_simple(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: torch.LongTensor,
    block_size: int = 64,
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (torch.LongTensor):
            Block counts of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        block_size (int):
            Selected block size. Default: 64.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    scale = k.shape[-1]**-0.5
    dtype = q.dtype
    HQ = q.shape[2]
    H = k.shape[2]
    D = k.shape[-1]
    G = HQ // H
    BS = block_size
    S = block_indices.shape[-1]
    SELECTED_BLOCKS_SIZE = S * BS
    k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
    block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
    c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
    q, k, v = map(lambda x: x.float(), (q, k, v))

    o = torch.zeros_like(v)
    B, T = q.shape[:2]
    for i in range(B):
        q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
        # [T, HQ, S, BS] -> [T, HQ, S*BS]
        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, HQ, S*BS] -> [T, S*BS, HQ]
        i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
        for i_q in range(T):
            # [HQ, D]
            q_i = q_b[i_q] * scale
            # [S*BS, HQ] -> represents selected blocks for each query token
            i_i = i_b[i_q]
            # [HQ] -> represents the number of selected blocks for each query token
            s_i = s_b[i_q]
            k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
            v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
            for h in range(HQ):
                for t in range(SELECTED_BLOCKS_SIZE):
                    selected_block_index = i_i[t, h]
                    k_i[t, h] = k_b[selected_block_index, h, :]
                    v_i[t, h] = v_b[selected_block_index, h, :]
            # [S*BS, HQ]
            attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
            attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf'))
            attn = torch.softmax(attn, dim=0)
            o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
    return o.to(dtype)
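
For reference, a minimal usage sketch of naive_nsa_simple run as a standalone script (assuming reference.py is importable from the working directory; the shapes and the simple "first S blocks" selection are illustrative only and satisfy the documented constraint that HQ // H is a power of 2 and >= 16):

import torch
from reference import naive_nsa_simple

B, T, H, HQ, D, S, BS = 1, 64, 1, 16, 32, 4, 16
q = torch.randn(B, T, HQ, D, dtype=torch.float16)
k = torch.randn(B, T, H, D, dtype=torch.float16)
v = torch.randn(B, T, H, D, dtype=torch.float16)
# Every query position selects the first S blocks; block_counts marks all S slots valid.
block_indices = torch.arange(S).view(1, 1, 1, S).expand(B, T, H, S).contiguous()
block_counts = torch.full((B, T, H), S, dtype=torch.long)
o = naive_nsa_simple(q, k, v, block_indices, block_counts, block_size=BS)
print(o.shape)  # torch.Size([1, 64, 16, 32])

Each query position attends only to tokens inside its selected blocks that do not lie in the future, which is what the (i_i > i_q) mask enforces above.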


def naive_nsa_simple_inference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: torch.LongTensor,
    block_size: int = 64,
) -> torch.Tensor:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
        block_counts (torch.LongTensor):
            Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`.
        block_size (int):
            Selected block size. Default: 64.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    scale = k.shape[-1]**-0.5
    dtype = q.dtype
    HQ = q.shape[2]
    H = k.shape[2]
    D = k.shape[-1]
    G = HQ // H
    BS = block_size
    S = block_indices.shape[-1]
    SELECTED_BLOCKS_SIZE = S * BS
    k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
    block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
    c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
    q, k, v = map(lambda x: x.float(), (q, k, v))

    o = torch.zeros_like(q)
    B, T = q.shape[:2]
    for i in range(B):
        q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
        # [T, HQ, S, BS] -> [T, HQ, S*BS]
        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, HQ, S*BS] -> [T, S*BS, HQ]
        i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
        # [HQ, D]
        q_i = q_b[0] * scale
        # [S*BS, HQ] -> represents selected blocks for each query token
        i_i = i_b[0]
        # [HQ] -> represents the number of selected blocks for each query token
        s_i = s_b[0]
        k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
        v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
        for h in range(HQ):
            for t in range(SELECTED_BLOCKS_SIZE):
                selected_block_index = i_i[t, h]
                k_i[t, h] = k_b[selected_block_index, h, :]
                v_i[t, h] = v_b[selected_block_index, h, :]
        # [S*BS, HQ]
        attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
        attn = attn.masked_fill((c >= s_i), float('-inf'))
        attn = torch.softmax(attn, dim=0)
        o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
    return o.to(dtype)
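
The inference variant is called the same way but with a single query position and no causal mask, mirroring the decode example's __main__ above; a minimal sketch with illustrative shapes:

import torch
from reference import naive_nsa_simple_inference

B, T, H, HQ, D, S, BS = 1, 64, 1, 16, 32, 4, 16
q = torch.randn(B, 1, HQ, D, dtype=torch.float16)   # single decode-step query
k = torch.randn(B, T, H, D, dtype=torch.float16)    # full K cache
v = torch.randn(B, T, H, D, dtype=torch.float16)    # full V cache
block_indices = torch.arange(S).view(1, 1, 1, S).expand(B, 1, H, S).contiguous()
block_counts = torch.full((B, 1, H), S, dtype=torch.long)
o = naive_nsa_simple_inference(q, k, v, block_indices, block_counts, block_size=BS)
print(o.shape)  # torch.Size([1, 1, 16, 32])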