"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "31a247b7a9f44b1fbb71c5a35c10fb71b8bcca5f"
Commit 7ccec53b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Feature] Support Async Pipeline inference within if scope (#198)

* Optimize CMake build process with dynamic job count calculation

- Modify build_csrc function to use 90% of available CPU cores
- Ensure at least one job is used during compilation
- Improve build performance by dynamically adjusting parallel job count

* Optimize build_csrc function with multiprocessing module

- Replace os.cpu_count() with multiprocessing.cpu_count()
- Maintain existing 90% CPU utilization logic
- Improve CPU core count calculation for build process

* Add dynamic shape support with out_idx in Cython JIT kernel compilation

- Implement `run_cython_dynamic_shape_with_out_idx` function in test_tilelang_jit_gemm_cython.py
- Update Cython wrapper to handle dynamic symbolic shapes during tensor allocation
- Add support for resolving dynamic shape dimensions using input tensor references
- Enhance flexibility of JIT kernel compilation with symbolic shape handling

* Enhance error reporting for dynamic symbolic shape resolution in Cython JIT kernel

- Add detailed error message when a dynamic symbolic dimension is not found in dynamic_symbolic_map
- Improve debugging by providing context about missing symbolic dimensions
- Maintain existing dynamic shape resolution logic

* Fix Copy operation handling for scalar and multi-dimensional tensors

- Add special handling for scalar tensor copy operations
- Enhance error reporting in MakeIndices method with more detailed diagnostic information
- Improve SIMT loop generation to support zero-dimensional tensors
- Add explicit check and handling for scalar tensor scenarios

* Refactor Copy operation code formatting and improve readability

- Improve code formatting in MakeIndices and MakeSIMTLoop methods
- Add line breaks to enhance readability of complex ICHECK statements
- Simplify code structure in scalar tensor handling
- Remove unnecessary whitespace and improve code alignment

* Simplify GEMM example with direct kernel compilation

- Update copyright header to Tile-AI Corporation
- Remove Profiler import and usage
- Replace tilelang.lower() with tilelang.compile()
- Simplify kernel execution workflow
- Update kernel source retrieval method

* Enhance block sparse attention implementation

- Update `blocksparse_flashattn` to use 2 stages for improved performance.
- Change `block_mask_dtype` from `int8` to `bool` for better memory efficiency.
- Modify condition checks in the kernel to utilize boolean values.
- Introduce a new example for top-k sparse attention and a benchmark for native sparse attention.
- Add support for asynchronous copy in PTX and improve pipeline planning with condition handling.

* Refactor and clean up code formatting across multiple files

- Added whitespace for improved readability in `example_blocksparse_gemm.py`, `example_tilelang_nsa_fwd.py`, and `benchmark_nsa_fwd.py`.
- Enhanced code structure and alignment in `inject_ptx_async_copy.cc` and `pipeline_planning.cc`.
- Updated comments and documentation for clarity in `__init__.py` and `phase.py`.
- Ensured consistent formatting and style across the codebase.
parent 20f19611
......@@ -39,7 +39,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
num_stages = 0
num_stages = 2
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
......@@ -47,7 +47,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "int8"
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
......@@ -159,7 +159,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k] != 0:
if block_mask[k]:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
......@@ -187,8 +187,6 @@ def benchmark_topk_sparse_attention():
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import math
import torch
......@@ -34,7 +32,7 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
block_M = 64
block_N = 64
num_stages = 0
num_stages = 1
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
......@@ -42,7 +40,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "int8"
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
......@@ -196,7 +194,7 @@ def test_topk_sparse_attention():
# Run Triton kernel
program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=[4])
print(kernel.get_kernel_source())
tilelang_output = kernel(q, k, v, block_mask)
# Compute reference
......@@ -215,8 +213,7 @@ def test_topk_sparse_attention():
print("tilelang_output", tilelang_output)
# Verify accuracy
assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
"TileLang output doesn't match reference"
torch.testing.assert_close(tilelang_output, ref_output, atol=1e-2, rtol=1e-2)
print("Pass topk sparse attention test with qlen == klen")
......
import tilelang
import tilelang.language as T
import torch
torch.random.manual_seed(0)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
BlockMask: T.Buffer(block_mask_shape, "bool"),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
if BlockMask[by, bx, k]:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func)
kernel = tilelang.compile(func, out_idx=-1)
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
# block_mask = torch.zeros(1024 // 128, 1024 // 128, 1024 // 32).cuda().bool()
# block_mask = torch.ones(1024 // 128, 1024 // 128, 1024 // 32).cuda().bool()
# random mask
block_mask = torch.randint(0, 2, (1024 // 128, 1024 // 128, 1024 // 32)).cuda().bool()
c = kernel(a, b, block_mask)
ref_c = torch.zeros_like(c)
for i in range(1024 // 128):
for j in range(1024 // 128):
accu = torch.zeros((128, 128), dtype=torch.float32, device=a.device)
for k in range(1024 // 32):
if block_mask[i, j, k]:
accu += (
a[i * 128:(i + 1) * 128, k * 32:(k + 1) * 32].to(torch.float32)
@ b[k * 32:(k + 1) * 32, j * 128:(j + 1) * 128].to(torch.float32))
ref_c[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = accu.to(torch.float16)
# ref_c = a @ b
print(c)
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print(kernel.get_kernel_source())
......@@ -211,9 +211,10 @@ if __name__ == "__main__":
if (not args.tune):
program = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=args.tune)(
block_M=64, block_N=64, num_stages=0, threads=128)
block_M=64, block_N=64, num_stages=1, threads=128)
ref_program = partial(ref_program, is_causal=is_causal)
kernel = tilelang.compile(program, out_idx=[3])
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......
# ruff: noqa
import torch
import time
import argparse
import tilelang
from tilelang import language as T
import tilelang.testing
from typing import Optional, Union
from einops import rearrange, repeat
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension can not be larger than 256"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`,
each token can select the same number of blocks.
If not provided, it will default to `S`, Default: `None`
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[int]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
def naive_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`,
each token can select the same number of blocks.
If not provided, it will default to `S`, Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[int]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat(
[block_indices.new_tensor(range(0, B * T, T)),
block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
if not varlen:
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def tilelang_sparse_attention(batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=1,
selected_blocks=16):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
assert NK == 1, "The key dimension can not be larger than 256"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
num_stages = 2
threads = 32
@T.prim_func
def tilelang_sparse_attention(
Q: T.Buffer(q_shape, dtype),
K: T.Buffer(kv_shape, dtype),
V: T.Buffer(kv_shape, dtype),
BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
Output: T.Buffer(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV])
return tilelang_sparse_attention
def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size):
"""Generate random block indices for the benchmark."""
block_indices = torch.full((batch, seq_len, heads, selected_blocks),
seq_len,
dtype=torch.long,
device='cuda')
for b in range(batch):
for t in range(seq_len):
for h in range(heads):
i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks]
block_indices[b, t, h, :len(i_i)] = i_i
return block_indices.sort(-1)[0]
def benchmark_nsa(batch_size,
seq_len,
heads,
head_query,
dim,
selected_blocks,
block_size,
dtype,
scale,
warmup=10,
iterations=100,
validate=False):
"""Benchmark the TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
tilelang.testing.set_random_seed(0)
torch.random.manual_seed(0)
# Compile the NSA kernel
program = tilelang_sparse_attention(
batch=batch_size,
heads=head_query,
seq_len=seq_len,
dim=dim,
is_causal=True,
block_size=block_size,
groups=head_query // heads,
selected_blocks=selected_blocks,
scale=scale,
)
print(program)
kernel = tilelang.compile(program, out_idx=-1)
print(kernel.get_kernel_source())
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda')
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
# Warmup
for _ in range(warmup):
out = kernel(Q, K, V, block_indices.to(torch.int32))
# Synchronize before timing
torch.cuda.synchronize()
# Benchmark
start_time = time.time()
for _ in range(iterations):
out = kernel(Q, K, V, block_indices.to(torch.int32))
torch.cuda.synchronize()
end_time = time.time()
# Calculate metrics
elapsed_time = end_time - start_time
avg_time = elapsed_time / iterations * 1000 # ms
# Calculate FLOPs (approximate for NSA)
# Each token attends to selected_blocks * block_size tokens
# Each attention calculation involves 2*dim FLOPs for QK
# And another 2*dim FLOPs for attention * V
flops_per_token = 4 * dim * selected_blocks * block_size
total_flops = batch_size * seq_len * head_query * flops_per_token
flops_per_sec = total_flops / (elapsed_time / iterations)
tflops = flops_per_sec / 1e12
# Validate result against reference if requested
if validate:
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda')
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda')
block_counts = torch.randint(
1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda')
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
scale=scale,
)
is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2)
if is_valid:
print("Validation: PASSED")
else:
print("Validation: FAILED")
print(f"Max difference: {(ref - out).abs().max().item()}")
# Return benchmark results
return {
"avg_time_ms": avg_time,
"tflops": tflops,
"batch_size": batch_size,
"seq_len": seq_len,
"heads": heads,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size
}
def benchmark_triton_nsa(batch_size,
seq_len,
heads,
head_query,
dim,
selected_blocks,
block_size,
dtype,
scale,
warmup=10,
iterations=100,
validate=False):
"""Benchmark the Triton-based TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
tilelang.testing.set_random_seed(0)
torch.random.manual_seed(0)
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device='cuda')
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device='cuda')
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda')
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device='cuda')
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
block_counts = torch.randint(
1, selected_blocks + 1, (batch_size, seq_len, heads), device='cuda')
# Warmup
for _ in range(warmup):
out = parallel_nsa_fwd(
q=Q,
k=K,
v=V,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale)
# Synchronize before timing
torch.cuda.synchronize()
# Benchmark
start_time = time.time()
for _ in range(iterations):
out = parallel_nsa_fwd(
q=Q,
k=K,
v=V,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale)
torch.cuda.synchronize()
end_time = time.time()
# Calculate metrics
elapsed_time = end_time - start_time
avg_time = elapsed_time / iterations * 1000 # ms
# Calculate FLOPs (approximate for NSA)
flops_per_token = 4 * dim * selected_blocks * block_size
total_flops = batch_size * seq_len * head_query * flops_per_token
flops_per_sec = total_flops / (elapsed_time / iterations)
tflops = flops_per_sec / 1e12
# Validate result against reference if requested
if validate:
ref = naive_nsa(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
scale=scale,
)
is_valid = torch.allclose(ref, out, atol=1e-2, rtol=1e-2)
if is_valid:
print("Validation: PASSED")
else:
print("Validation: FAILED")
print(f"Max difference: {(ref - out).abs().max().item()}")
# Return benchmark results
return {
"avg_time_ms": avg_time,
"tflops": tflops,
"batch_size": batch_size,
"seq_len": seq_len,
"heads": heads,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size
}
def run_benchmark_suite(impl='all'):
"""Run a suite of benchmarks with different configurations."""
# Define configurations to benchmark
configs = [
# Small model config - Note: head_query must be a multiple of heads*16 for Triton
{
"batch_size": 2,
"seq_len": 1024,
"heads": 8,
"head_query": 8 * 16,
"dim": 64,
"selected_blocks": 8,
"block_size": 32
},
# Medium model config
{
"batch_size": 2,
"seq_len": 2048,
"heads": 16,
"head_query": 16 * 16,
"dim": 64,
"selected_blocks": 16,
"block_size": 64
},
# Large model config
{
"batch_size": 1,
"seq_len": 4096,
"heads": 32,
"head_query": 32 * 16,
"dim": 128,
"selected_blocks": 32,
"block_size": 128
},
]
results = []
for config in configs:
print(f"Running benchmark with config: {config}")
if impl in ['all', 'tilelang']:
print("Benchmarking TileLang implementation:")
result = benchmark_nsa(
batch_size=config["batch_size"],
seq_len=config["seq_len"],
heads=config["heads"],
head_query=config["head_query"],
dim=config["dim"],
selected_blocks=config["selected_blocks"],
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False)
results.append({"impl": "tilelang", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ['all', 'triton']:
print("Benchmarking Triton implementation:")
result = benchmark_triton_nsa(
batch_size=config["batch_size"],
seq_len=config["seq_len"],
heads=config["heads"],
head_query=config["head_query"],
dim=config["dim"],
selected_blocks=config["selected_blocks"],
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False)
results.append({"impl": "triton", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ['all']:
# Print comparison if both implementations were run
tilelang_result = next(
r for r in results if r["impl"] == "tilelang" and
r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"])
triton_result = next(
r for r in results if r["impl"] == "triton" and
r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"])
speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"]
print(f"Speedup (Triton vs TileLang): {speedup:.2f}x")
print("-" * 50)
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark TileLang Sparse Attention")
parser.add_argument("--batch", type=int, default=2, help="Batch size")
parser.add_argument("--seq_len", type=int, default=1024, help="Sequence length")
parser.add_argument("--heads", type=int, default=1, help="Number of heads")
parser.add_argument("--head_query", type=int, default=16, help="Number of query heads")
parser.add_argument("--dim", type=int, default=64, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=8, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=64, help="Block size")
parser.add_argument(
"--dtype", type=str, default="float16", help="Data type (float16 or float32)")
parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
parser.add_argument("--validate", action="store_true", help="Validate against reference")
parser.add_argument("--suite", action="store_true", help="Run benchmark suite")
parser.add_argument(
"--impl",
type=str,
default="tilelang",
choices=["tilelang", "triton", "all"],
help="Implementation to benchmark (tilelang, triton, or all)")
args = parser.parse_args()
# For Triton impl, ensure head_query is a multiple of heads*16
if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0:
# Adjust head_query to nearest valid value
args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16)
print(
f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation")
if args.suite:
run_benchmark_suite(impl=args.impl)
else:
dtype = torch.float16 if args.dtype == "float16" else torch.float32
if args.impl in ["tilelang", "all"]:
print("Benchmarking TileLang implementation:")
result = benchmark_nsa(
batch_size=args.batch,
seq_len=args.seq_len,
heads=args.heads,
head_query=args.head_query,
dim=args.dim,
selected_blocks=args.selected_blocks,
block_size=args.block_size,
dtype=dtype,
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate)
print("\nBenchmark Results (TileLang):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " +
f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " +
f"block_size={args.block_size}")
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if args.impl in ["triton", "all"]:
print("Benchmarking Triton implementation:")
result = benchmark_triton_nsa(
batch_size=args.batch,
seq_len=args.seq_len,
heads=args.heads,
head_query=args.head_query,
dim=args.dim,
selected_blocks=args.selected_blocks,
block_size=args.block_size,
dtype=dtype,
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate)
print("\nBenchmark Results (Triton):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " +
f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " +
f"block_size={args.block_size}")
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
......@@ -40,7 +40,7 @@ def native_sparse_attention(batch,
G = groups
BS = block_S
BK = BV = block_T
num_stages = 0
num_stages = 2
threads = 32
@T.prim_func
......@@ -140,6 +140,7 @@ if __name__ == "__main__":
scale=scale,
)
kernel = tilelang.compile(program, out_idx=-1)
torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
......
......@@ -158,84 +158,6 @@ def naive_nsa(q: torch.Tensor,
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def naive_nsa_simple_inference(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, 1, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, 1, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, 1, H]` if `head_first=False` else `[B, H, T]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1]**-0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
# [HQ, D]
q_i = q_b[0] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[0]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[0]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, D), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((c >= s_i), float('-inf'))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
return o.to(dtype)
def naive_nsa_simple(
q: torch.Tensor,
k: torch.Tensor,
......
......@@ -227,11 +227,12 @@ class PipelineRewriter : public StmtExprMutator {
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info)
const For &pipeline_loop, const PipelineInfo &pipeline_info,
PrimExpr predicate_condition = PrimExpr())
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {}
pipeline_info_(pipeline_info),
predicate_condition_(predicate_condition) {}
Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
......@@ -636,6 +637,7 @@ private:
// Async related
std::map<int, AsyncStateLocal> async_states_local;
PrimExpr normalized_access_index;
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage;
......@@ -658,7 +660,7 @@ private:
// - "producer_head" if this stage is an async producer
// - "consumer_head" if this stage reads from asynchronously written
// buffers.
PrimExpr normalized_access_index =
normalized_access_index =
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
// Adjust the block predicate and the body according to the final loop
......@@ -668,10 +670,15 @@ private:
Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
}
new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
if (predicate_condition_.defined()) {
BlockNode *n = new_block.CopyOnWrite();
n->body = IfThenElse(
Substitute(predicate_condition_,
{{pipeline_loop_->loop_var, normalized_access_index}}),
n->body);
}
if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
......@@ -687,7 +694,6 @@ private:
PopulateWaitCounts(new_blocks, &async_states_local);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
Stmt new_loop{nullptr};
if (stmts.empty()) {
......@@ -713,7 +719,6 @@ private:
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
std::move(new_loop), NullOpt, preserved_annotations);
}
// Update producer heads in the global async states.
for (const auto &[stage_id, state] : async_states_local) {
async_states[stage_id].producer_head += extent;
......@@ -728,6 +733,7 @@ private:
Array<Buffer> pipeline_allocs_;
For pipeline_loop_;
PipelineInfo pipeline_info_;
PrimExpr predicate_condition_;
int max_stage_ = -1;
Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_;
......@@ -842,6 +848,7 @@ private:
// can be direct child of the for-loop. If the for-loop has BlockRealize as
// its child, the pipeline body will be the child of the block.
Stmt pipeline_body{nullptr};
PrimExpr predicate_condition{nullptr};
Array<Buffer> pipeline_allocs;
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......@@ -849,7 +856,15 @@ private:
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
ICHECK(!if_then_else->else_case.defined())
<< "Pipeline_Planning: Can't handle the body of the loop because "
"it is not a SeqStmt";
pipeline_body = if_then_else->then_case;
predicate_condition = if_then_else->condition;
} else {
pipeline_body = block->body;
}
pipeline_allocs = block->alloc_buffers;
} else {
pipeline_body = for_node->body;
......@@ -927,9 +942,10 @@ private:
ValidatePipelineBody(pipeline_info, original_order);
// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
.BuildPipeline();
Stmt pipeline =
PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info, predicate_condition)
.BuildPipeline();
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Replace copy from global to shared with async copy
* \file inject_ptx_async_copy.cc
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "storage_access.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
namespace tvm {
namespace tl {
using namespace tir;
class PTXAsyncCopyInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode *attr) {
if (attr->attr_key == tir::attr::async_scope) {
ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
auto body = this->VisitStmt(attr->body);
in_async = false;
return body;
}
return StmtMutator::VisitStmt_(attr);
}
Stmt InjectPTX(const BufferLoadNode *load, const BufferStoreNode *store,
bool predicated = false,
PrimExpr predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() ==
store->indices[0]->dtype.lanes())
<< load->indices[0] << " vs. " << store->indices[0] << " with lanes "
<< load->indices[0]->dtype.lanes() << " vs. "
<< store->indices[0]->dtype.lanes();
const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();
if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type =
GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type =
GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type "
"annotation.";
int index_factor = 1;
if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when the
// dst shared memory is a byte buffer generated by merging dynamic
// shared memory.
ICHECK(store->buffer.scope() == "shared.dyn" ||
store->buffer.scope() == "shared");
ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according
// to their "value" dtype. Their "indices" are supposed to be applied
// after such pointer cast, for example:
// ((*float16)(byte_buffer))[buffer->indices] = fp16_value; To replace
// BufferStore/Load with cp.async, we need to multiply the store index
// by the byte size of the "value" dtype, to get the correct offset
// into the byte buffer.
index_factor = src_elem_type->bytes();
}
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {
store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated
// cp.async
if (predicated) {
args.push_back(predicate_value);
}
return Evaluate(Call(store->buffer->dtype,
tvm::tir::builtin::ptx_cp_async(), args));
}
// Predicated load don't support vectorized indexing.
if (!predicated) {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();
auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by
// merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) +
// x8(17408))] = A_global[ramp(...),1, 8)]
auto *add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>())
return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>())
return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base,
add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();
auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by
// merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) +
// x8(17408))] = A_global[ramp(...),1, 8)]
auto *add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>())
return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>())
return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base,
add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes),
predicate_value}));
}
}
}
}
return StmtMutator::VisitStmt_(store);
}
Stmt VisitStmt_(const BufferStoreNode *store) {
if (in_async && (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn")) {
if (auto *load = store->value.as<BufferLoadNode>()) {
return InjectPTX(load, store);
} else if (auto *call = store->value.as<CallNode>()) {
// tir.if_then_else is a call to tir::builtin::if_then_else()
if (call->op.same_as(builtin::if_then_else()) &&
call->args.size() == 3) {
if (auto *load = call->args[1].as<BufferLoadNode>()) {
// Only default value of 0 is supported since 0 is the default value
// used by cp.async ptx. @see section 9.7.8.22.3. of
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
bool else_value_is_zero = false;
if (auto *b = call->args[2].as<BroadcastNode>()) {
if (auto *f = b->value.as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
} else if (auto *i = b->value.as<IntImmNode>()) {
else_value_is_zero = i->value == 0;
}
}
if (auto *f = call->args[2].as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
} else if (auto *i = call->args[2].as<IntImmNode>()) {
else_value_is_zero = i->value == 0;
}
if (else_value_is_zero) {
return InjectPTX(load, store, true, call->args[0]);
}
}
}
}
}
return StmtMutator::VisitStmt_(store);
}
private:
bool in_async{false};
};
using namespace tir::transform;
tvm::transform::Pass InjectPTXAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = PTXAsyncCopyInjector()(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}
TVM_REGISTER_GLOBAL("tl.transform.InjectPTXAsyncCopy")
.set_body_typed(InjectPTXAsyncCopy);
} // namespace tl
} // namespace tvm
......@@ -34,8 +34,6 @@ namespace tl {
using namespace tir;
namespace {
/*!
* \brief Check whether two regions have intersections.
* \param region1 The first region.
......@@ -56,8 +54,6 @@ bool MayConflict(Region region1, Region region2) {
return true;
}
} // namespace
class PipelinePlanner : public StmtExprMutator {
public:
static Stmt Substitute(const PrimFunc &f) {
......@@ -88,20 +84,24 @@ private:
/*body*/ stmt);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
PipelineStageInfo pinfo;
pinfo.reads = std::move(access[0]);
pinfo.writes = std::move(access[1]);
pinfo.original_order = idx;
// copy stage should only have one reads and one writes
if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global")
pinfo.copy_stage = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "global")
pinfo.copy_stage = true;
}
bool write_to_shared = false;
bool read_from_global = false;
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global")
read_from_global = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "shared" ||
region->buffer.scope() == "shared.dyn")
write_to_shared = true;
pinfo.copy_stage = write_to_shared && read_from_global;
return std::move(pinfo);
}
......@@ -118,14 +118,26 @@ private:
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
pipeline_body = block->body;
} else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
// should assert else case is nullptr
ICHECK(!if_then_else->else_case.defined())
<< "Pipeline_Planning: Can't handle the body of the loop because "
"it is not a SeqStmt";
pipeline_body = if_then_else->then_case;
} else {
LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
"because it is not a SeqStmt or IfThenElse";
}
} else {
pipeline_body = loop->body;
}
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< loop->body->GetTypeKey();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< pipeline_body->GetTypeKey() << " " << pipeline_body;
CHECK(num_stages >= 1);
CHECK(loop->kind == ForKind::kSerial);
......@@ -156,10 +168,12 @@ private:
return r->buffer == write->buffer &&
MayConflict(r->region, write->region);
}) != pinfo.writes.end()) {
CHECK(false) << "Can't handle multiple write on overlap buffer "
"region in the pipeline "
"planning pass: "
<< pipeline_body_seq->seq[pinfo.original_order];
LOG(FATAL) << "Pipeline planning error: Multiple writes to "
"overlapping buffer regions detected. "
<< "Stage " << pinfo.original_order << " and stage " << i
<< " are both writing to buffer '" << write->buffer->name
<< "' with overlapping regions. This is not supported "
"in pipeline planning.";
}
}
}
......
......@@ -44,6 +44,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
# TODO(lei): may need a pass to fuse the if-then-else in the
# pipeline loop when we meet dynamic branch.
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
......@@ -74,7 +76,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
......
......@@ -225,3 +225,14 @@ def VectorizeLoop(enable_vectorize: bool = True):
The result pass
"""
return _ffi_api.VectorizeLoop(enable_vectorize) # type: ignore
def InjectPTXAsyncCopy():
"""Rewrite global to shared memory copy on CUDA with asynchronous copy.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectPTXAsyncCopy() # type: ignore
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment