"...text-generation-inference.git" did not exist on "54fec9319371b2792526e0cbfebe6cee66ed3980"
Commit b6c48453 authored by Lei Wang, committed by GitHub

[Bugfix] Cast bool dtype into int8 in blocksparse examples (#167)

* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation

- Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
- Modify roller hints generation using new TileLang Carver template and utility functions
- Update get_roller_hints_from_func to handle None cases and improve return logic
- Adjust DefaultPolicy to handle different codegen dictionary formats

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels

- Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files
- Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
- Move map_torch_type utility function to tilelang.utils.tensor
- Remove unnecessary imports and improve code organization

* Refactor Native Sparse Attention Example with Enhanced Triton Kernel

- Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
- Add support for block counts and offsets in the Triton kernel
- Modify kernel grid and computation logic for improved performance
- Update example script to use naive_nsa_simple reference implementation
- Improve type hints and kernel configuration
parent de1ba1e4
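
The fix named in the title is a call-site change: the blocksparse examples now pass their boolean block mask through .to(torch.int8) before invoking the compiled TileLang kernel, presumably because the kernel declares that argument as an int8 buffer and rejects torch.bool inputs. A minimal sketch of the pattern, with an illustrative wrapper name (only the cast itself is taken from the diff below):

import torch

def run_blocksparse(kernel, q, k, v, block_mask):
    # `kernel` stands in for the callable returned by tilelang.compile(...).
    # Cast the boolean block mask to int8 instead of passing torch.bool directly,
    # mirroring the block_mask.to(torch.int8) calls in the test diff below.
    if block_mask.dtype == torch.bool:
        block_mask = block_mask.to(torch.int8)
    return kernel(q, k, v, block_mask)
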
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 # ruff: noqa
 import torch
-from typing import Optional
+from typing import Optional, Union
 import torch
 import triton
@@ -12,34 +12,25 @@ from einops import rearrange
 from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens, prepare_token_indices)
 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
-from reference import naive_nsa
+from reference import naive_nsa, naive_nsa_simple
 
 
+@triton.heuristics({
+    'USE_OFFSETS': lambda args: args['offsets'] is not None,
+    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
+})
 @triton.autotune(
-    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16]],
+    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
     key=['BS', 'BK', 'BV'],
 )
 @triton.jit
-def parallel_nsa_fwd_kernel(
-    q,
-    k,
-    v,
-    o,
-    lse,
-    scale,
-    block_indices,
-    T,
-    H: tl.constexpr,
-    HQ: tl.constexpr,
-    G: tl.constexpr,
-    K: tl.constexpr,
-    V: tl.constexpr,
-    S: tl.constexpr,
-    BS: tl.constexpr,
-    BK: tl.constexpr,
-    BV: tl.constexpr,
-):
-    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
+                            block_counts, offsets, token_indices, T, H: tl.constexpr,
+                            HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
+                            S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
+                            BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
+                            USE_BLOCK_COUNTS: tl.constexpr):
+    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
     i_b, i_h = i_bh // H, i_bh % H
     bos, eos = i_b * T, i_b * T + T
@@ -48,65 +39,99 @@ def parallel_nsa_fwd_kernel(
     v += (bos * H + i_h) * V
     block_indices += (bos + i_t) * H * S + i_h * S
 
+    # if USE_BLOCK_COUNTS:
+    #     NS = tl.load(block_counts + (bos + i_t) * H + i_h)
+    # else:
     NS = S
 
     p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
                             (1, 0))
-    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV),
-                            (1, 0))
-    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
 
     # the Q block is kept in the shared memory throughout the whole kernel
     # [G, BK]
     b_q = tl.load(p_q, boundary_check=(0, 1))
     b_q = (b_q * scale).to(b_q.dtype)
 
+    p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
+                                (G, BV), (1, 0))
+    p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
+
     # [G, BV]
-    b_o = tl.zeros([G, BV], dtype=tl.float32)
-
-    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
-    b_acc = tl.zeros([G], dtype=tl.float32)
+    b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
+
+    b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
+    b_acc_slc = tl.zeros([G], dtype=tl.float32)
     for i in range(NS):
         i_s = tl.load(block_indices + i).to(tl.int32) * BS
-        if i_s <= i_t:
-            p_k = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
-            p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
+        if i_s <= i_t and i_s >= 0:
+            p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
+            p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
             # [BK, BS]
-            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
             # [BS, BV]
-            b_v = tl.load(p_v, boundary_check=(0, 1))
+            b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
             # [G, BS]
-            b_s = tl.dot(b_q, b_k)
-            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))
+            b_s_slc = tl.dot(b_q, b_k_slc)
+            if i_t == 6:
+                print("b_s_slc", b_s_slc)
+            b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
             # [G]
-            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
-            b_r = tl.exp(b_mp - b_m)
+            b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
+            b_r_slc = tl.exp(b_mp_slc - b_m_slc)
             # [G, BS]
-            b_p = tl.exp(b_s - b_m[:, None])
+            b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
             # [G]
-            b_acc = b_acc * b_r + tl.sum(b_p, 1)
+            b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
             # [G, BV]
-            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
-
-            b_mp = b_m
-    b_o = b_o / b_acc[:, None]
-    b_m += tl.log(b_acc)
-    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
-    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))
+            b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
+            b_mp_slc = b_m_slc
+
+    b_o_slc = b_o_slc / b_acc_slc[:, None]
+    b_m_slc += tl.log(b_acc_slc)
+    tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
+
+
+class ParallelNSAFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
+        ctx.dtype = q.dtype
+
+        # 2-d sequence indices denoting the offsets of tokens in each sequence
+        # for example, if the passed `offsets` is [0, 2, 6],
+        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
+        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+        token_indices = prepare_token_indices(offsets) if offsets is not None else None
+
+        o, lse = parallel_nsa_fwd(
+            q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
+        ctx.save_for_backward(q, k, v, o, lse)
+        ctx.block_indices = block_indices
+        ctx.block_size = block_size
+        ctx.scale = scale
+        return o.to(q.dtype)
 
 
 def parallel_nsa_fwd(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    block_indices: torch.Tensor,
+    block_indices: torch.LongTensor,
+    block_counts: Union[torch.LongTensor, int],
     block_size: int,
-    scale: float,
+    window_size: int,
+    offsets: Optional[torch.LongTensor] = None,
+    token_indices: Optional[torch.LongTensor] = None,
 ):
+    import math
     B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
     HQ = q.shape[2]
+    scale = 1.0 / math.sqrt(K)
     G = HQ // H
     BS = block_size
+    WS = window_size
     if torch.cuda.get_device_capability()[0] >= 9:
         BK = min(256, triton.next_power_of_2(K))
         BV = min(256, triton.next_power_of_2(V))
@@ -117,106 +142,43 @@ def parallel_nsa_fwd(
     NV = triton.cdiv(V, BV)
     assert NK == 1, "The key dimension can not be larger than 256"
-    grid = (NV, T, B * H)
-    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
-    lse = torch.empty(B, T, HQ, dtype=torch.float32, device=q.device)
-    print("grid", grid)
+    grid = (T, NV, B * H)
+    o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
+    o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
+    lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
+    lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
 
     parallel_nsa_fwd_kernel[grid](
         q=q,
         k=k,
         v=v,
-        o=o,
-        lse=lse,
+        o_slc=o_slc,
+        o_swa=o_swa,
+        lse_slc=lse_slc,
+        lse_swa=lse_swa,
         scale=scale,
         block_indices=block_indices,
+        block_counts=block_counts,
+        offsets=offsets,
+        token_indices=token_indices,
+        T=T,
         H=H,
         HQ=HQ,
         G=G,
-        T=T,
         K=K,
         V=V,
         S=S,
         BS=BS,
+        WS=WS,
         BK=BK,
         BV=BV,
     )
-    return o, lse
-
-
-class ParallelNSAFunction(torch.autograd.Function):
-
-    @staticmethod
-    @contiguous
-    @autocast_custom_fwd
-    def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
-        ctx.dtype = q.dtype
-
-        # 2-d sequence indices denoting the offsets of tokens in each sequence
-        # for example, if the passed `offsets` is [0, 2, 6],
-        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
-        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
-        token_indices = prepare_token_indices(offsets) if offsets is not None else None
-
-        o, lse = parallel_nsa_fwd(
-            q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
-        ctx.save_for_backward(q, k, v, o, lse)
-        ctx.block_indices = block_indices
-        ctx.block_size = block_size
-        ctx.scale = scale
-        return o.to(q.dtype)
-
-
-def parallel_nsa(q: torch.Tensor,
-                 k: torch.Tensor,
-                 v: torch.Tensor,
-                 block_indices: torch.LongTensor,
-                 block_size: int = 64,
-                 scale: Optional[float] = None,
-                 cu_seqlens: Optional[torch.LongTensor] = None,
-                 head_first: bool = False) -> torch.Tensor:
-    r"""
-    Args:
-        q (torch.Tensor):
-            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
-        k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
-            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
-        v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
-        block_indices (torch.LongTensor):
-            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
-            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
-        block_size (int):
-            Selected block size. Default: 64.
-        scale (Optional[int]):
-            Scale factor for attention scores.
-            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format. Default: `False`.
-        cu_seqlens (torch.LongTensor):
-            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-            consistent with the FlashAttention API.
-
-    Returns:
-        o (torch.Tensor):
-            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
-    """
-    if scale is None:
-        scale = k.shape[-1]**-0.5
-    if cu_seqlens is not None:
-        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
-    if head_first:
-        q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
-                                     (q, k, v, block_indices))
-    o = ParallelNSAFunction.apply(q, k, v, block_indices, block_size, scale, cu_seqlens)
-    if head_first:
-        o = rearrange(o, 'b t h d -> b h t d')
-    return o
+    return o_slc, lse_slc, o_swa, lse_swa
 
 
 if __name__ == "__main__":
-    B, T, H, HQ, D, S, block_size, dtype, scale = 1, 64, 1, 16, 32, 1, 64, torch.float16, 0.1
+    B, T, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
+    torch.random.manual_seed(0)
     q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
     k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
     v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
@@ -232,32 +194,25 @@ if __name__ == "__main__":
     block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
 
-    ref = naive_nsa(
+    ref = naive_nsa_simple(
         q=q,
         k=k,
         v=v,
         block_indices=block_indices,
         block_counts=block_counts,
         block_size=block_size,
-        scale=scale)
+    )
+    # print(ref)
 
-    tri = parallel_nsa(
-        q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
+    tri, _, _, _ = parallel_nsa_fwd(
+        q=q,
+        k=k,
+        v=v,
+        block_indices=block_indices,
+        block_size=block_size,
+        window_size=0,
+        block_counts=block_counts)
 
-    # print(tri)
+    print("tri", tri)
+    print("ref", ref)
 
     torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
-
-    # import flash_attn
-    # # gqa
-    # o_gqa = flash_attn.flash_attn_func(
-    #     q,
-    #     k,
-    #     v,
-    #     softmax_scale=scale,
-    # )
-    # print(o_gqa)
-    # torch.testing.assert_close(o_gqa, tri, atol=1e-2, rtol=1e-2)
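
The b_m_slc / b_acc_slc / b_o_slc updates in the kernel hunks above are the standard online-softmax recurrence used by FlashAttention-style kernels: each selected key/value block rescales the running maximum, denominator, and partial output. A plain PyTorch sketch of the same recurrence, useful as a reference while reading the Triton code (illustrative names, causal masking omitted):

import torch

def online_softmax_attention(q, k_blocks, v_blocks):
    """Streaming (online) softmax attention over a list of key/value blocks.

    q:        [G, K] query rows, already multiplied by the softmax scale
    k_blocks: list of [BS, K] key blocks
    v_blocks: list of [BS, V] value blocks
    Returns (o [G, V], lse [G]); mirrors the b_m_slc / b_acc_slc / b_o_slc
    recurrence in parallel_nsa_fwd_kernel.
    """
    G = q.shape[0]
    V = v_blocks[0].shape[1]
    m = torch.full((G,), float('-inf'))   # running row maxima
    acc = torch.zeros(G)                  # running softmax denominators
    o = torch.zeros(G, V)                 # running (unnormalized) output
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        s = q @ k_blk.T                   # [G, BS] block scores
        m_new = torch.maximum(m, s.amax(dim=1))
        r = torch.exp(m - m_new)          # rescales the old partial sums
        p = torch.exp(s - m_new[:, None]) # unnormalized probabilities
        acc = acc * r + p.sum(dim=1)
        o = o * r[:, None] + p @ v_blk
        m = m_new
    return o / acc[:, None], m + torch.log(acc)

# Sanity check against a dense softmax over the concatenated blocks.
torch.manual_seed(0)
q = torch.randn(4, 8)
ks = [torch.randn(16, 8) for _ in range(3)]
vs = [torch.randn(16, 5) for _ in range(3)]
o, lse = online_softmax_attention(q, ks, vs)
ref = torch.softmax(q @ torch.cat(ks).T, dim=-1) @ torch.cat(vs)
assert torch.allclose(o, ref, atol=1e-5)
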
@@ -181,7 +181,7 @@ def test_topk_sparse_attention():
         BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
     kernel = tilelang.compile(program, out_idx=[4])
     print(kernel.get_kernel_source())
-    tilelang_output = kernel(q, k, v, block_mask)
+    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
 
     # Compute reference
     # Expand block mask to full attention matrix
@@ -231,13 +231,7 @@ def test_topk_sparse_attention_qlen_lt_klen():
     print(program)
     kernel = tilelang.compile(program, out_idx=[4])
    print(kernel.get_kernel_source())
-    tilelang_output = kernel(q, k, v, block_mask)
-
-    # import flash_attn
-    # ref_out = flash_attn.flash_attn_func(q, k, v, causal=True)
-    # torch.testing.assert_close(tilelang_output, ref_out, atol=1e-2, rtol=1e-2)
-    # exit()
+    tilelang_output = kernel(q, k, v, block_mask.to(torch.int8))
 
     past_len = K_LEN - Q_LEN
@@ -268,5 +262,5 @@
 
 
 if __name__ == "__main__":
-    # test_topk_sparse_attention()
+    test_topk_sparse_attention()
     test_topk_sparse_attention_qlen_lt_klen()
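
The @triton.heuristics decorator added to the NSA example computes USE_OFFSETS and USE_BLOCK_COUNTS from the launch arguments, so the kernel is specialized at compile time on whether block_counts is a tensor or a plain integer. A self-contained sketch of that mechanism, unrelated to the NSA kernel itself (kernel and variable names here are made up):

import torch
import triton
import triton.language as tl

@triton.heuristics({
    # Resolved per launch and passed to the kernel as a tl.constexpr flag.
    'USE_COUNTS': lambda args: isinstance(args['counts'], torch.Tensor),
})
@triton.jit
def row_block_sum_kernel(x, counts, out, S: tl.constexpr, BS: tl.constexpr,
                         USE_COUNTS: tl.constexpr):
    pid = tl.program_id(0)
    # Because USE_COUNTS is constexpr, the untaken branch is removed at
    # compile time, so `counts` may legally be a plain integer.
    if USE_COUNTS:
        n = tl.load(counts + pid)
    else:
        n = S
    acc = tl.zeros([BS], dtype=tl.float32)
    for i in range(n):
        offs = i * BS + tl.arange(0, BS)
        acc += tl.load(x + pid * S * BS + offs)
    tl.store(out + pid, tl.sum(acc, axis=0))

# Each row of `x` holds S blocks of BS elements; sum only the first
# counts[row] blocks per row (or all S blocks when an int is passed).
x = torch.randn(4, 8 * 16, device='cuda')
counts = torch.randint(1, 9, (4,), device='cuda', dtype=torch.int32)
out = torch.empty(4, device='cuda')
row_block_sum_kernel[(4,)](x, counts, out, S=8, BS=16)
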