Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
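The pre-commit change below removes the two yapf hooks and adds ruff-format alongside the existing ruff-check hook. For contributors, the equivalent local run is sketched here under stated assumptions (ruff installed at the pinned version, pre-commit set up in the checkout; the repository's own [tool.ruff] settings such as line length are not shown in this commit):

    # Run every configured hook over the whole tree, as CI would.
    pre-commit run --all-files

    # Or mirror the two ruff hooks below directly:
    ruff format .                                 # formatter, replacing yapf
    ruff check --fix --exit-non-zero-on-fix .     # lint with autofix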
@@ -39,19 +39,9 @@ repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.14.7  # sync with requirements-lint.txt
     hooks:
+      - id: ruff-format
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]
-  - repo: https://github.com/google/yapf
-    rev: v0.43.0  # sync with requirements-lint.txt
-    hooks:
-      - id: yapf
-        name: yapf-multiproc-bugfix
-        # yapf is not multiprocess safe, so we run a dummy yapf first.
-        args: [--in-place, docs/conf.py]
-        always_run: true
-        pass_filenames: false
-      - id: yapf
-        args: [--recursive, --in-place]
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1  # sync with requirements-lint.txt
     hooks:
@@ -62,4 +52,4 @@ repos:
           ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
           ^.+\.svg$|
           ^.*\brequirements\b.*\.txt$
-        )
\ No newline at end of file
+        )
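Both hooks above carry a "sync with requirements-lint.txt" note, so that file presumably gains a matching ruff pin while the yapf pin goes away. A hypothetical sketch of the affected entries, for illustration only (the actual file is not part of this hunk):

    # requirements-lint.txt (assumed contents)
    ruff==0.14.7
    codespell==2.4.1
    # yapf==0.43.0 removed together with its hooks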
@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         import flash_attn
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_N = 64
     num_stages = 2
     threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     block_mask_dtype = "bool"

     def kernel_func(block_M, block_N, num_stages, threads):
-
         @T.macro
         def MMA0(
             K: T.Tensor(shape, dtype),
@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
             if is_causal:
                 for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
             T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
             by: T.int32,
             bz: T.int32,
         ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
             T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

         @T.macro
         def Softmax(
             acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
             acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
             scores_max: T.FragmentBuffer([block_M], accum_dtype),
             scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
             scores_sum: T.FragmentBuffer([block_M], accum_dtype),
             logsum: T.FragmentBuffer([block_M], accum_dtype),
         ):
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
@@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
         @T.macro
         def Rescale(
             acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
             scores_scale: T.FragmentBuffer([block_M], accum_dtype),
         ):
             for i, j in T.Parallel(block_M, dim):
                 acc_o[i, j] *= scores_scale[i]

         @T.prim_func
         def main(
             Q: T.Tensor(shape, dtype),
             K: T.Tensor(shape, dtype),
             V: T.Tensor(shape, dtype),
             BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
             Output: T.Tensor(shape, dtype),
         ):
-            with T.Kernel(
-                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
                 V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)

-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                     block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                 loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
+                )

                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     if block_mask[k]:
                         MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                                scores_sum, logsum)
+                        Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                         Rescale(acc_o, scores_scale)
                         MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

         return main
@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                           device='cuda',
-                           dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

-        program = blocksparse_flashattn(
-            BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
+        program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
         kernel = tilelang.compile(program, out_idx=4)

         def benchmark_fn():
...
@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                           device='cuda',
-                           dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

         def benchmark_fn():
             # Compute reference
             # Expand block mask to full attention matrix
-            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
+            full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
             full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
             full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal

             # PyTorch reference implementation
-            attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
-            attn = attn.masked_fill(~full_mask, float('-inf'))
+            attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
+            attn = attn.masked_fill(~full_mask, float("-inf"))
             attn = F.softmax(attn, dim=-1)
-            ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
+            ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
             return ref_output

         ref_latency = do_bench(
...
@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
-                            False,
-                            dtype=torch.bool,
-                            device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
         dense_mask[:, :, -2:, :] = True
@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)

     if mask_val == True:
@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
         # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
         if LAST_K_BLOCK:
-            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
-                           float('-inf'))
+            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))

         m_ij = tl.maximum(m_i, tl.max(qk, 1))
         qk -= m_ij[:, None]
@@ -153,7 +148,7 @@ def _fwd_kernel(
     v_ptrs = V + off_v
     mask_ptrs = block_mask_ptr + start_m * stride_bmm

-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -191,24 +186,12 @@ def _fwd_kernel(
     acc = acc * l_recip
     acc = acc.to(Out.dtype.element_ty)

-    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
-        None, :] * stride_od
+    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)


-def _forward(ctx,
-             q,
-             k,
-             v,
-             block_sparse_mask,
-             sm_scale,
-             BLOCK_M=64,
-             BLOCK_N=64,
-             num_warps=None,
-             num_stages=1,
-             out=None):
+def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]
     assert k.shape[2] == v.shape[2]
     o = out if out is not None else torch.empty_like(q).contiguous()
@@ -253,7 +236,6 @@ def _forward(ctx,
 class _sparse_attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
         # shape constraints
@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
 def benchmark_topk_sparse_attention():
     from benchmark_configs import configs

     torch.manual_seed(0)

     # Config
     for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
         # Create inputs
-        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
-        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
+        q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
+        v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
         sm_scale = 1.0 / (D_HEAD**0.5)

         # Create sparse mask (downsampled to block level)
         downsample_factor = BLOCK
         downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
-                           device='cuda',
-                           dtype=torch.bfloat16)
+        x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
         x_ds[:, :, :, 0] = 100
         block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
...
@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
     dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
     decay = torch.exp(dt_segment_sum)
     scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
-    causal_mask = torch.tril(
-        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
+    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
     scores_decay = scores_decay.masked_fill(~causal_mask, 0)
-    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
-                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
+    out = torch.einsum(
+        "bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
+    )
     state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
-    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
-        C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    out_prev = (
+        torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
+    )
     out = out + out_prev
     out = rearrange(out, "b c l h p -> b (c l) h p")
     if D is not None:
@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):

 def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
-
     @helion.kernel()
     def helion_mamba2_chunk_scan_kernel(
         cb: torch.Tensor,
@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         dtype = cb.dtype
         accum_dtype = torch.float32
-        assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
-                dtype)
+        assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype

         out = torch.empty_like(x)
@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
         for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
             [nheads, chunk_size, headdim, batch, nchunks],
             block_size=[1, block_m, block_n, 1, 1],
         ):
             acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
-            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
-                                          tile_m].to(torch.float32)
+            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
             scale_m_local = torch.exp2(dA_cumsum_local_m * p)
             C_local = C[
@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
                     tile_m,
                     tile_k,
                 ]
-                dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
-                                              tile_k].to(torch.float32)
-                cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
-                                       dA_cumsum_local_k[None, :] * p)
+                dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
+                cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
                 dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
                 cb_local = (cb_local * dt_local[None, :]).to(dtype)
                 pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
                 acc_o = hl.dot(cb_local, x_local, acc=acc_o)

             D_local = D[tile_h.begin].to(torch.float32)
-            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
-                           tile_n].to(torch.float32)
+            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
             acc_o += x_residual * D_local

-            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
-                tile_n] = acc_o.to(dtype=dtype)
+            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)

         return out
@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):

 def get_configs():
-    iter_params = dict(
-        block_M=[64, 128, 256],
-        block_N=[32, 64],
-        block_K=[64, 128, 256],
-        block_Dstate=[128],
-        num_stages=[1, 2, 3, 4, 5])
+    iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@@ -198,19 +187,21 @@ def get_configs():
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
 )
-def chunk_scan_fwd(batch,
-                   seqlen,
-                   chunk_size,
-                   ngroups,
-                   nheads,
-                   headdim,
-                   dstate,
-                   block_M=64,
-                   block_N=64,
-                   block_K=64,
-                   block_Dstate=128,
-                   num_stages=2,
-                   threads=128):
+def chunk_scan_fwd(
+    batch,
+    seqlen,
+    chunk_size,
+    ngroups,
+    nheads,
+    headdim,
+    dstate,
+    block_M=64,
+    block_N=64,
+    block_K=64,
+    block_Dstate=128,
+    num_stages=2,
+    threads=128,
+):
     dtype = "float16"
     accum_dtype = "float"
     nchunks = T.ceildiv(seqlen, chunk_size)
@@ -218,20 +209,20 @@ def chunk_scan_fwd(batch,
     @T.prim_func
     def main(
         cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
         x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
         dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
         dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
         C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
         prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
         D: T.Tensor((nheads), dtype),  # type: ignore
-        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)  # type: ignore
+        Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
     ):
-        with T.Kernel(
-                nheads,
-                T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
-                batch * nchunks,
-                threads=threads) as (bz, bx, by):
+        with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
+            bz,
+            bx,
+            by,
+        ):
             acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
             acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
             cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
             m_idx = bx // T.ceildiv(headdim, block_N)
             n_idx = bx % T.ceildiv(headdim, block_N)

-            T.annotate_layout({
-                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
-                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
-                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
-            })
+            T.annotate_layout(
+                {
+                    acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
+                    cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
+                    x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
+                }
+            )

             T.no_set_max_nreg()
-            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
-                   dA_cs_m_shared)
+            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
             T.copy(dA_cs_m_shared, dA_cs_m_local)

             T.clear(acc_o)
             for i in T.Parallel(block_M):
                 scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
             T.copy(
-                C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
-            T.copy(
-                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
-                            0:block_Dstate], prev_state_shared)
+                C[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz // (nheads // ngroups),
+                    0:block_Dstate,
+                ],
+                C_shared,
+            )
+            T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
             T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] *= scale_m_local[i]
@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
-                       m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
-                    cb_shared)
+                    cb[
+                        batch_idx,
+                        chunk_idx,
+                        bz // (nheads // ngroups),
+                        m_idx * block_M : (m_idx + 1) * block_M,
+                        k * block_K : (k + 1) * block_K,
+                    ],
+                    cb_shared,
+                )
                 T.copy(cb_shared, cb_local)
-                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
-                       dA_cs_k_shared)
+                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
                 T.copy(dA_cs_k_shared, dA_cs_k_local)

                 for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i,
-                             j] = cb_local[i,
-                                           j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
-                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
+                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
+                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
                 T.copy(dt_shared, dt_local)
                 for i, j in T.Parallel(block_M, block_K):
                     cb_local[i, j] *= dt_local[j]
                 for i, j in T.Parallel(block_M, block_K):
-                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
-                                                    cb_local[i, j], 0)
+                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
                 T.copy(
-                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
-                      (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
+                    x[
+                        batch_idx,
+                        chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
+                        bz,
+                        n_idx * block_N : (n_idx + 1) * block_N,
+                    ],
+                    x_shared,
+                )
                 T.gemm(cb_local, x_shared, acc_o)

             D_local[0] = D[bz]
             T.copy(
-                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                  (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
-                x_residual_shared)
+                x[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+                x_residual_shared,
+            )
             T.copy(x_residual_shared, x_residual_local)
             for i, j in T.Parallel(block_M, block_N):
                 acc_o[i, j] += x_residual_local[i, j] * D_local[0]
@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
             T.copy(acc_o, acc_o_shared)
             T.copy(
                 acc_o_shared,
-                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
-                       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
+                Output[
+                    batch_idx,
+                    chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
+                    bz,
+                    n_idx * block_N : (n_idx + 1) * block_N,
+                ],
+            )

     return main


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=80, help='heads')
-    parser.add_argument('--groups', type=int, default=1, help='groups')
-    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
-    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
-    parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--dstate', type=int, default=128, help='dstate')
-    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--heads", type=int, default=80, help="heads")
+    parser.add_argument("--groups", type=int, default=1, help="groups")
+    parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
+    parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
+    parser.add_argument("--dim", type=int, default=64, help="dim")
+    parser.add_argument("--dstate", type=int, default=128, help="dstate")
+    parser.add_argument("--tune", action="store_true", help="tune configs")
     args = parser.parse_args()
-    batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    batch, heads, groups, seq_len, chunk_size, dim, dstate = (
+        args.batch,
+        args.heads,
+        args.groups,
+        args.seq_len,
+        args.chunk_size,
+        args.dim,
+        args.dstate,
+    )
     nchunks = math.ceil(seq_len / chunk_size)
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
@@ -360,8 +382,7 @@ if __name__ == "__main__":
     D = torch.randn(heads).half().cuda()

     print("Benchmarking Triton...")
-    triton_latency = do_bench(
-        lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
+    triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
     print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")

     print("Benchmarking Helion...")
...
@@ -6,6 +6,7 @@ import tilelang
 import tilelang.language as T
 from tilelang.autotuner import autotune
 from tilelang import jit
+
 # Configure logger
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
             policy=[T.GemmWarpPolicy.Square],
             enable_rasteration=[True, False],
         )
-        return [{
-            k: v for k, v in zip(iter_params, values)
-        } for values in itertools.product(*iter_params.values())]
+        return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -159,9 +160,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
@@ -176,7 +177,6 @@ def matmul(
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
@@ -6,7 +6,8 @@ import tilelang as tl
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mma_macro_generator import (
-    TensorCoreIntrinEmitter,)
+    TensorCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 from tilelang.autotuner import autotune
 import itertools
@@ -103,12 +104,11 @@ def tl_matmul(
     @T.prim_func
     def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
+        A: T.Tensor(A_shape, in_dtype),
+        B: T.Tensor(B_shape, in_dtype),
+        C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -116,10 +116,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-            T.annotate_layout({
-                A_shared: make_swizzle_layout(A_shared),
-                B_shared: make_swizzle_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: make_swizzle_layout(A_shared),
+                    B_shared: make_swizzle_layout(B_shared),
+                }
+            )

             # Improve L2 Cache
             T.use_swizzle(panel_size=10, enable=enable_rasteration)
@@ -127,7 +129,6 @@ def tl_matmul(
             T.clear(C_local)

             for ko in T.Pipelined((K // block_K), num_stages=stage):
-
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -137,7 +138,6 @@ def tl_matmul(
                     B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                 for ki in T.serial(0, (block_K // micro_size_k)):
-
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(A_local, A_shared, ki)
@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
         for config in configs:
             print(config)
     else:
-
         iter_params = dict(
             block_row_warps=[1, 2, 4],
             block_col_warps=[1, 2, 4],
@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
             stage=[0, 2],
             enable_rasteration=[True, False],
         )
-        return [{
-            k: v for k, v in zip(iter_params, values)
-        } for values in itertools.product(*iter_params.values())]
+        return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
     ref_prog=ref_program,
     skip_check=True,
 )
-@tl.jit(out_idx=[2],)
+@tl.jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -291,13 +290,8 @@ if __name__ == "__main__":
     parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
-    parser.add_argument(
-        "--with_roller",
-        type=bool,
-        default=False,
-        help="Whether to use roller to deduce search spaces")
-    parser.add_argument(
-        "--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
+    parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
+    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
     args = parser.parse_args()
     M, N, K = args.m, args.n, args.k
...
@@ -70,7 +70,8 @@ def get_configs(M, N, K):
             thread_num,
             policy,
             enable_rasterization,
-        ))
+        )
+    )

     configs = [
         {
@@ -81,7 +82,8 @@ def get_configs(M, N, K):
             "thread_num": c[4],
             "policy": c[5],
             "enable_rasterization": c[6],  # keep param name for backward-compat
-        } for c in _configs
+        }
+        for c in _configs
     ]
     return configs
@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
         warmup=3,
         rep=20,
     )
-    @jit(out_idx=[2],)
+    @jit(
+        out_idx=[2],
+    )
     def kernel(
         block_M=None,
         block_N=None,
@@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
         @T.prim_func
         def main(
-                A_sparse: T.Tensor((M, K // 2), in_dtype),
-                E: T.Tensor((M, K // e_factor), e_dtype),
-                B: T.Tensor((K, N), in_dtype),
-                C: T.Tensor((M, N), accum_dtype),
+            A_sparse: T.Tensor((M, K // 2), in_dtype),
+            E: T.Tensor((M, K // e_factor), e_dtype),
+            B: T.Tensor((K, N), in_dtype),
+            C: T.Tensor((M, N), accum_dtype),
         ):
             """
             The compiled TVM function for block-level matrix multiplication.
@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
             """
             # Bind x-dimension to block index in N,
             # y-dimension to block index in M.
-            with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-
+            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                 # Allocate shared memory for A sub-block of shape (block_M, block_K)
                 A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
                 # Allocate shared memory for B sub-block of shape (block_N, block_K)
@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
                 T.disable_warp_group_reg_alloc()
                 T.use_swizzle(panel_size=10, enable=enable_rasterization)

-                T.annotate_layout({
-                    E:
-                        make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
-                    E_shared:
-                        make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
-                })
+                T.annotate_layout(
+                    {
+                        E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
+                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
+                    }
+                )

                 # Loop over sub-blocks in K dimension, pipelined by num_stages
                 for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                     # Load a sub-block of A from global memory into A_shared
@@ -241,18 +243,13 @@ if __name__ == "__main__":
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument("--disable_cache", action="store_true")
-    parser.add_argument(
-        "--accum_dtype",
-        type=str,
-        default="float",
-        choices=["float", "float16"],
-        help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
     parser.add_argument(
         "--bench_torch_sparse",
         type=str,
-        choices=['cutlass', 'cusparselt'],
+        choices=["cutlass", "cusparselt"],
         default=None,
-        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
+        help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported",
     )
     args = parser.parse_args()
@@ -274,7 +271,8 @@ if __name__ == "__main__":
     if args.bench_torch_sparse is not None:
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        if args.bench_torch_sparse == 'cutlass':
+
+        if args.bench_torch_sparse == "cutlass":
             SparseSemiStructuredTensor._FORCE_CUTLASS = True
             A_sp = to_sparse_semi_structured(A, transposed=False)
             torch_sparse_latency = do_bench(lambda: A_sp @ B)
@@ -285,8 +283,6 @@ if __name__ == "__main__":
     print(f"Best config: {best_config}")
     if args.bench_torch_sparse is not None:
-        print(
-            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
-        )
+        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
     print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
             policy=[T.GemmWarpPolicy.Square],
             enable_rasteration=[True, False],
         )
-        return [{
-            k: v for k, v in zip(iter_params, values)
-        } for values in itertools.product(*iter_params.values())]
+        return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]

     return configs
@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
     warmup=3,
     rep=20,
 )
-@jit(out_idx=[2],)
+@jit(
+    out_idx=[2],
+)
 def matmul(
     M,
     N,
@@ -164,9 +164,9 @@ def matmul(
     @T.prim_func
     def main(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), dtype),
+        A: T.Tensor((M, K), dtype),
+        B: T.Tensor((N, K), dtype),
+        C: T.Tensor((M, N), dtype),
     ):
         """
         The compiled TVM function for block-level matrix multiplication.
@@ -181,7 +181,6 @@ def matmul(
         # Bind x-dimension to block index in N,
         # y-dimension to block index in M.
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-
             # Allocate shared memory for A sub-block of shape (block_M, block_K)
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
...@@ -20,33 +20,27 @@ extensions = [ ...@@ -20,33 +20,27 @@ extensions = [
"autoapi.extension", "autoapi.extension",
] ]
autoapi_type = 'python' autoapi_type = "python"
autoapi_dirs = ['../tilelang'] autoapi_dirs = ["../tilelang"]
autoapi_options = [ autoapi_options = [
'members', "members",
'undoc-members', "undoc-members",
'show-inheritance', "show-inheritance",
'show-module-summary', "show-module-summary",
'special-members', "special-members",
] ]
autoapi_keep_files = False # Useful for debugging the generated rst files autoapi_keep_files = False # Useful for debugging the generated rst files
autoapi_generate_api_docs = True autoapi_generate_api_docs = True
autodoc_typehints = 'description' autodoc_typehints = "description"
autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"] autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]
source_suffix = { source_suffix = {".rst": "restructuredtext", ".md": "markdown"}
'.rst': 'restructuredtext',
'.md': 'markdown',
}
myst_enable_extensions = [ myst_enable_extensions = ["colon_fence", "deflist"]
"colon_fence",
"deflist",
]
redirects = {"get_started/try_out": "../index.html#getting-started"} redirects = {"get_started/try_out": "../index.html#getting-started"}
...@@ -66,10 +60,7 @@ html_css_files = ["custom.css"] ...@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
footer_copyright = "© 2025-2026 TileLang" footer_copyright = "© 2025-2026 TileLang"
footer_note = " " footer_note = " "
html_theme_options = { html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"}
"light_logo": "img/logo-v2.png",
"dark_logo": "img/logo-v2.png",
}
header_links = [ header_links = [
("Home", "https://github.com/tile-ai/tilelang"), ("Home", "https://github.com/tile-ai/tilelang"),
......
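One behavior that shows up on nearly every changed line of the docs/conf.py hunk above is quote normalization: ruff format rewrites single-quoted strings to double quotes (the Black-compatible default), and the short dict and list literals here end up collapsed onto a single line. A tiny self-contained echo of the settings above (values copied from the diff, so nothing new is asserted about the Sphinx config):

# Before (yapf era):   autoapi_type = 'python'
# After ruff format:   autoapi_type = "python"
autoapi_type = "python"
autodoc_typehints = "description"
source_suffix = {".rst": "restructuredtext", ".md": "markdown"}

print(autoapi_type, autodoc_typehints, sorted(source_suffix))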
...@@ -11,22 +11,20 @@ import time ...@@ -11,22 +11,20 @@ import time
def ref_program(Q, K, V, is_causal, groups=1): def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size( assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1) dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2) K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2) V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float() lse = torch.logsumexp(scores, dim=-1).float()
return output, lse return output, lse
...@@ -45,23 +43,23 @@ def get_fwd_configs(): ...@@ -45,23 +43,23 @@ def get_fwd_configs():
valid_configs = [] valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
threads, num_stages, block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
enable_rasterization, k_pack, ):
panel_size, qk_coalesced_width, valid_configs.append(
v_coalesced_width): {
valid_configs.append({ "block_M": m,
"block_M": m, "block_N": n,
"block_N": n, "num_split_q": s,
"num_split_q": s, "threads": t,
"threads": t, "num_stages": stages,
"num_stages": stages, "enable_rasterization": r,
"enable_rasterization": r, "k_pack": k,
"k_pack": k, "panel_size": p,
"panel_size": p, "qk_coalesced_width": qkw,
"qk_coalesced_width": qkw, "v_coalesced_width": vw,
"v_coalesced_width": vw, }
}) )
return valid_configs return valid_configs
...@@ -85,7 +83,7 @@ def fast_flashattn( ...@@ -85,7 +83,7 @@ def fast_flashattn(
qk_coalesced_width: int, qk_coalesced_width: int,
v_coalesced_width: int, v_coalesced_width: int,
): ):
scale = (1.0 / dim)**0.5 scale = (1.0 / dim) ** 0.5
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim] kv_shape = [batch, seq_len, head_kv, dim]
...@@ -97,11 +95,11 @@ def fast_flashattn( ...@@ -97,11 +95,11 @@ def fast_flashattn(
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype), LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
): ):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization) T.use_swizzle(panel_size, enable=enable_rasterization)
...@@ -135,33 +133,21 @@ def fast_flashattn( ...@@ -135,33 +133,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype) m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy( T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
loop_end_k = ( loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
T.ceildiv(q_block_offset +
block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
row_sum = T.alloc_fragment([block_M], accum_dtype) row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages): for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N kv_idx = k * block_N
T.copy( T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
K[bz, kv_idx:kv_idx + block_N, by // groups, :], T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(
...@@ -216,8 +202,7 @@ def fast_flashattn( ...@@ -216,8 +202,7 @@ def fast_flashattn(
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if q_block_offset + i < seq_len: if q_block_offset + i < seq_len:
lse_val = T.if_then_else(l_i[i] > 0, lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
LSE[bz, by, q_block_offset + i] = lse_val LSE[bz, by, q_block_offset + i] = lse_val
bx_loop_var = current_bx + num_split_q bx_loop_var = current_bx + num_split_q
...@@ -234,16 +219,17 @@ def get_bwd_configs(): ...@@ -234,16 +219,17 @@ def get_bwd_configs():
panel_size = [7, 8, 9, 10] panel_size = [7, 8, 9, 10]
configs = [] configs = []
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size):
enable_rasterization, panel_size): configs.append(
configs.append({ {
"block_M": m, "block_M": m,
"block_N": n, "block_N": n,
"num_stages": stages, "num_stages": stages,
"threads": t, "threads": t,
"enable_rasterization": r, "enable_rasterization": r,
"panel_size": p, "panel_size": p,
}) }
)
return configs return configs
...@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
blk = 32 blk = 32
@T.prim_func @T.prim_func
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by): with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype) do = T.alloc_fragment([blk, blk], dtype)
...@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim): ...@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True) @tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
@tilelang.jit @tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int, def flashattn_bwd(
num_stages: int, threads: int, enable_rasterization: bool, panel_size: int): batch,
sm_scale = (1.0 / dim)**0.5 heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_stages: int,
threads: int,
enable_rasterization: bool,
panel_size: int,
):
sm_scale = (1.0 / dim) ** 0.5
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim] kv_shape = [batch, seq_len, head_kv, dim]
...@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b ...@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
accum_dtype = "float" accum_dtype = "float"
@T.prim_func @T.prim_func
def flash_bwd_kernel(Q: T.Tensor(q_shape, def flash_bwd_kernel(
dtype), K: T.Tensor(kv_shape, Q: T.Tensor(q_shape, dtype),
dtype), V: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len], V: T.Tensor(kv_shape, dtype),
accum_dtype), dO: T.Tensor(q_shape, dtype),
Delta: T.Tensor([batch, heads, seq_len], lse: T.Tensor([batch, heads, seq_len], accum_dtype),
accum_dtype), dQ: T.Tensor(q_shape, accum_dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)): dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype),
dV: T.Tensor(kv_shape, accum_dtype),
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
T.use_swizzle(panel_size, enable=enable_rasterization) T.use_swizzle(panel_size, enable=enable_rasterization)
...@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b ...@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
dk = T.alloc_fragment([block_M, dim], accum_dtype) dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype) dq = T.alloc_fragment([block_N, dim], accum_dtype)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared) T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared) T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
...@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b ...@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
loop_ed = T.ceildiv(seq_len, block_N) loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared) T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j]) P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)
P_acc[i, j], 0.0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared) T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
T.clear(dP) T.clear(dP)
T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b ...@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
T.copy(P_acc, p_cast) T.copy(P_acc, p_cast)
T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
...@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)): def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy( T.copy(
dQ_in[bz, bx * blk:(bx + 1) * blk, by, :], dQ_in[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
) )
return flash_bwd_post return flash_bwd_post
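The slice rewrites above (bx * blk:(bx + 1) * blk becoming bx * blk : (bx + 1) * blk, and likewise throughout this diff) follow the PEP 8 rule that ruff format applies automatically: when a slice bound is a compound expression, the colon is treated like a low-priority operator and gets a space on each side, while simple bounds stay tight. A minimal sketch on a plain Python list:

blk, bx = 32, 2
data = list(range(1024))

# Simple operands keep the colon tight:
head = data[bx:blk]

# Compound operands get a space on each side of the colon after ruff format,
# mirroring dQ_in[bz, bx * blk : (bx + 1) * blk, by, :] above:
tile = data[bx * blk : (bx + 1) * blk]

print(len(head), len(tile))  # 30 32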
...@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100): ...@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
return np.median(times) return np.median(times)
def main(batch: int = 1, def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
device = "cuda" device = "cuda"
dtype = torch.float16 dtype = torch.float16
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
print( print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}")
f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}"
)
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm total_flops = 5 * flops_per_gemm
...@@ -517,22 +508,19 @@ def main(batch: int = 1, ...@@ -517,22 +508,19 @@ def main(batch: int = 1,
o_ref.backward(dO) o_ref.backward(dO)
print("Verifying backward pass correctness...") print("Verifying backward pass correctness...")
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
if dq_close: if dq_close:
print("dQ is correct.") print("dQ is correct.")
else: else:
print("dQ mismatch detected.") print("dQ mismatch detected.")
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
if dk_close: if dk_close:
print("dK is correct.") print("dK is correct.")
else: else:
print("dK mismatch detected.") print("dK mismatch detected.")
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
if dv_close: if dv_close:
print("dV is correct.") print("dV is correct.")
else: else:
...@@ -553,9 +541,7 @@ def main(batch: int = 1, ...@@ -553,9 +541,7 @@ def main(batch: int = 1,
torch.cuda.synchronize() torch.cuda.synchronize()
ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100) ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
print( print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops")
f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops"
)
def run_complete_fwd_bwd(): def run_complete_fwd_bwd():
o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v) o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
...@@ -593,12 +579,12 @@ def main(batch: int = 1, ...@@ -593,12 +579,12 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=8, help='heads') parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument('--seq_len', type=int, default=1024, help='sequence length') parser.add_argument("--seq_len", type=int, default=1024, help="sequence length")
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--groups', type=int, default=1, help='groups') parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
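The two signature rewrites in this file show the rule ruff format uses for parameter lists: if the whole header fits within the configured line length (again assumed to be around 120 characters), it is joined onto one line, as with def main(...); otherwise each parameter gets its own line plus a trailing comma, as with def flashattn_bwd(...). A short illustrative pair with placeholder names and bodies:

# Fits within the line length, so ruff format keeps it on one line:
def main_stub(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False) -> tuple:
    return batch, heads, seq_len, dim, is_causal

# Too long for one line, so each parameter is broken out with a trailing comma:
def flashattn_bwd_stub(
    batch: int,
    heads: int,
    seq_len: int,
    dim: int,
    is_causal: bool,
    groups: int,
    block_M: int,
    block_N: int,
    num_stages: int,
    threads: int,
) -> int:
    return batch * heads

print(main_stub())
print(flashattn_bwd_stub(1, 8, 4096, 128, False, 1, 64, 64, 2, 128))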
...@@ -13,10 +13,10 @@ def supply_tensors_gpu(params): ...@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
"""Supply function that creates tensors on GPU for ROCm/HIP.""" """Supply function that creates tensors on GPU for ROCm/HIP."""
tensors = [] tensors = []
for param in params: for param in params:
if hasattr(param, 'shape') and hasattr(param, 'dtype'): if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device # Force creation on GPU device
shape = [int(s) for s in param.shape] shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device='cuda') tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
tensors.append(tensor) tensors.append(tensor)
else: else:
tensors.append(param) tensors.append(param)
...@@ -24,22 +24,20 @@ def supply_tensors_gpu(params): ...@@ -24,22 +24,20 @@ def supply_tensors_gpu(params):
def ref_program(Q, K, V, is_causal, groups=1): def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size( assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1) dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2) K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2) V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal: if is_causal:
seq_len = Q.size(1) seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0) mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -58,23 +56,23 @@ def get_configs(): ...@@ -58,23 +56,23 @@ def get_configs():
valid_configs = [] valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q, for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
threads, num_stages, block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
enable_rasterization, k_pack, ):
panel_size, qk_coalesced_width, valid_configs.append(
v_coalesced_width): {
valid_configs.append({ "block_M": m,
"block_M": m, "block_N": n,
"block_N": n, "num_split_q": s,
"num_split_q": s, "threads": t,
"threads": t, "num_stages": stages,
"num_stages": stages, "enable_rasterization": r,
"enable_rasterization": r, "k_pack": k,
"k_pack": k, "panel_size": p,
"panel_size": p, "qk_coalesced_width": qkw,
"qk_coalesced_width": qkw, "v_coalesced_width": vw,
"v_coalesced_width": vw, }
}) )
return valid_configs return valid_configs
...@@ -98,7 +96,7 @@ def fast_flashattn( ...@@ -98,7 +96,7 @@ def fast_flashattn(
qk_coalesced_width: int, qk_coalesced_width: int,
v_coalesced_width: int, v_coalesced_width: int,
): ):
scale = (1.0 / dim)**0.5 scale = (1.0 / dim) ** 0.5
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim] kv_shape = [batch, seq_len, head_kv, dim]
...@@ -110,10 +108,10 @@ def fast_flashattn( ...@@ -110,10 +108,10 @@ def fast_flashattn(
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization) T.use_swizzle(panel_size, enable=enable_rasterization)
...@@ -147,32 +145,21 @@ def fast_flashattn( ...@@ -147,32 +145,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype) m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype) scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy( T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M, loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype) row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages): for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N kv_idx = k * block_N
T.copy( T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
K[bz, kv_idx:kv_idx + block_N, by // groups, :], T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(
...@@ -222,13 +209,7 @@ def fast_flashattn( ...@@ -222,13 +209,7 @@ def fast_flashattn(
return main return main
def main(batch: int = 1, def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
if is_causal: if is_causal:
...@@ -250,18 +231,16 @@ def main(batch: int = 1, ...@@ -250,18 +231,16 @@ def main(batch: int = 1,
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100) latency = profiler.do_bench(warmup=100)
print( print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=8, help='heads') parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument('--groups', type=int, default=1, help='groups') parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
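The scale = (1.0 / dim)**0.5 → (1.0 / dim) ** 0.5 change, repeated in every kernel in this diff, is the Black-compatible spacing rule for **: the operator stays hugged only when both operands are simple names or literals; a parenthesized expression on either side forces spaces. A quick check:

dim = 128

squared = dim**2  # both operands are simple, so ** stays tight
scale = (1.0 / dim) ** 0.5  # left operand is parenthesized, so spaces are added

print(squared, round(scale, 6))  # 16384 0.088388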
...@@ -25,22 +25,7 @@ def check_hopper(): ...@@ -25,22 +25,7 @@ def check_hopper():
return False return False
def kernel(N, def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
...@@ -50,13 +35,11 @@ def kernel(N, ...@@ -50,13 +35,11 @@ def kernel(N,
@T.prim_func @T.prim_func
def conv( def conv(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -65,11 +48,13 @@ def kernel(N, ...@@ -65,11 +48,13 @@ def kernel(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({ T.annotate_layout(
out_shared: make_swizzled_layout(out_shared), {
data_shared: make_swizzled_layout(data_shared), out_shared: make_swizzled_layout(out_shared),
kernel_shared: make_swizzled_layout(kernel_shared), data_shared: make_swizzled_layout(data_shared),
}) kernel_shared: make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -81,10 +66,8 @@ def kernel(N, ...@@ -81,10 +66,8 @@ def kernel(N,
m = by * block_M + i m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
(access_w < W)) data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local) T.gemm(data_shared, kernel_shared, out_local)
......
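The T.annotate_layout rewrite above shows how ruff format lays out a call whose only argument is a collection literal too long for one line: the call parentheses and the braces each open and close on their own lines, with the entries indented one extra level. A hedged stand-in (annotate_layout and the buffer names below are placeholders for illustration, not the real tilelang API):

def annotate_layout(mapping):  # placeholder helper for illustration only
    return {name: f"swizzled({buf})" for name, buf in mapping.items()}

layouts = annotate_layout(
    {
        "out_shared": "out_shared_buf",
        "data_shared": "data_shared_buf",
        "kernel_shared": "kernel_shared_buf",
    }
)
print(layouts["out_shared"])  # swizzled(out_shared_buf)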
...@@ -20,9 +20,9 @@ def kernel( ...@@ -20,9 +20,9 @@ def kernel(
@T.prim_func @T.prim_func
def matmul( def matmul(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
......
...@@ -51,8 +51,7 @@ def triton_kernel( ...@@ -51,8 +51,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH: if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else: else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T ...@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size, BANDWIDTH=window_size,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q) start_q=seq_kv - seq_q,
)
return o return o
...@@ -137,12 +137,11 @@ def main( ...@@ -137,12 +137,11 @@ def main(
): ):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= seq_q assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min( flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
...@@ -170,15 +169,14 @@ def main( ...@@ -170,15 +169,14 @@ def main(
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads, threads=threads,
dtype=dtype) dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose( if torch.allclose(
triton_program(Q, K, V, sinks, window_size), triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), ):
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅") print("Checks for triton passed.✅")
else: else:
print("Checks for triton failed.❌") print("Checks for triton failed.❌")
...@@ -198,20 +196,14 @@ def main( ...@@ -198,20 +196,14 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='heads') parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--groups', type=int, default=8, help='groups') parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument( parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
'--window_size', parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
type=int, parser.add_argument("--tune", action="store_true", help="tune configs")
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
args.dtype, args.tune)
...@@ -50,8 +50,7 @@ def triton_kernel( ...@@ -50,8 +50,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM]) q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH: if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else: else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
...@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T ...@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size, BANDWIDTH=window_size,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q) start_q=seq_kv - seq_q,
)
return o return o
def main(batch: int = 1, def main(
heads: int = 32, batch: int = 1,
seq_q: int = 256, heads: int = 32,
seq_kv: int = 256, seq_q: int = 256,
dim: int = 128, seq_kv: int = 256,
window_size: Optional[int] = None, dim: int = 128,
dtype: str = "float16", window_size: Optional[int] = None,
tune: bool = False): dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= seq_q assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min( flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
...@@ -163,15 +164,14 @@ def main(batch: int = 1, ...@@ -163,15 +164,14 @@ def main(batch: int = 1,
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads, threads=threads,
dtype=dtype) dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close( torch.testing.assert_close(
kernel(Q, K, V, sinks), kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), )
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500) latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
...@@ -184,19 +184,13 @@ def main(batch: int = 1, ...@@ -184,19 +184,13 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size') parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query') parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value') parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument( parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
'--window_size', parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
type=int, parser.add_argument("--tune", action="store_true", help="tune")
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
args.tune)
...@@ -20,28 +20,30 @@ def get_bwd_configs(): ...@@ -20,28 +20,30 @@ def get_bwd_configs():
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd( def flashattn_fwd(
batch, batch,
heads, heads,
seq_len, seq_len,
dim, dim,
groups=1, groups=1,
window_size=None, # None for full attention window_size=None, # None for full attention
sm_scale=None, sm_scale=None,
block_M=64, block_M=64,
block_N=64, block_N=64,
num_stages=1, num_stages=1,
threads=128, threads=128,
dtype: str = "float16"): dtype: str = "float16",
):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
...@@ -51,12 +53,12 @@ def flashattn_fwd( ...@@ -51,12 +53,12 @@ def flashattn_fwd(
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore V: T.Tensor(kv_shape, dtype), # type: ignore
Output: T.Tensor(q_shape, dtype), # type: ignore Output: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore Sinks: T.Tensor([heads], dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -73,7 +75,7 @@ def flashattn_fwd( ...@@ -73,7 +75,7 @@ def flashattn_fwd(
sinks = T.alloc_fragment([heads], dtype) sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -81,22 +83,20 @@ def flashattn_fwd( ...@@ -81,22 +83,20 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.max(0, start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
k_idx = k * block_N + j k_idx = k * block_N + j
if window_size is not None: if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
0, -T.infinity(acc_s.dtype))
else: else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
...@@ -106,8 +106,7 @@ def flashattn_fwd( ...@@ -106,8 +106,7 @@ def flashattn_fwd(
# NOTE(wt): check_inf is necessary for sliding window attention. # NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if window_size is not None: if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -124,22 +123,23 @@ def flashattn_fwd( ...@@ -124,22 +123,23 @@ def flashattn_fwd(
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
...@@ -147,9 +147,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") ...@@ -147,9 +147,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
...@@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16") ...@@ -158,26 +158,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
def make_dq_layout(dQ): def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
...@@ -185,32 +186,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16" ...@@ -185,32 +186,27 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore dQ_out: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy( T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :], dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
) )
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
def flashattn_bwd(batch, }
heads, )
seq_len, def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"): # None for full attention
dim,
groups,
window_size=None,
sm_scale=None,
dtype="float16"): # None for full attention
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
...@@ -225,15 +221,15 @@ def flashattn_bwd(batch, ...@@ -225,15 +221,15 @@ def flashattn_bwd(batch,
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(kv_shape, dtype), # type: ignore K: T.Tensor(kv_shape, dtype), # type: ignore
V: T.Tensor(kv_shape, dtype), # type: ignore V: T.Tensor(kv_shape, dtype), # type: ignore
dO: T.Tensor(q_shape, dtype), # type: ignore dO: T.Tensor(q_shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(kv_shape, accum_dtype), # type: ignore dK: T.Tensor(kv_shape, accum_dtype), # type: ignore
dV: T.Tensor(kv_shape, accum_dtype), # type: ignore dV: T.Tensor(kv_shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
...@@ -253,44 +249,47 @@ def flashattn_bwd(batch, ...@@ -253,44 +249,47 @@ def flashattn_bwd(batch,
dv_shared = T.alloc_shared([block_M, dim], accum_dtype) dv_shared = T.alloc_shared([block_M, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype) dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
T.copy(K[bz, bx // groups, by * block_M:(by + 1) * block_M, :], K_shared) }
T.copy(V[bz, bx // groups, by * block_M:(by + 1) * block_M, :], V_shared) )
T.copy(K[bz, bx // groups, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx // groups, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.min( loop_ed = (
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) if window_size is not None
else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
if window_size is not None: if window_size is not None:
qkT[i, j] = T.if_then_else( qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) )
else: else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0) T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
@@ -299,12 +298,12 @@ def flashattn_bwd(batch,
T.copy(dsT_cast, dsT_shared) T.copy(dsT_cast, dsT_shared)
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.atomic_add(dV[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dv_shared) T.atomic_add(dV[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.atomic_add(dK[bz, bx // groups, by * block_M:(by + 1) * block_M, :], dk_shared) T.atomic_add(dK[bz, bx // groups, by * block_M : (by + 1) * block_M, :], dk_shared)
return flash_bwd return flash_bwd
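
Both the forward and backward kernels gate scores with the same predicate: a query at position q may attend to key position k only when k <= q and, if a window is set, q < k + window_size. A minimal dense-mask sketch of that predicate (causal_window_mask is my own name, not part of this diff; past_len defaults to 0 as in this backward pass):

import torch

def causal_window_mask(seq_q, seq_kv, window_size=None, past_len=0):
    # True where a query may attend to a key: causal, and at most
    # window_size - 1 positions back when a window is given.
    q_idx = torch.arange(seq_q).unsqueeze(1) + past_len  # [seq_q, 1]
    k_idx = torch.arange(seq_kv).unsqueeze(0)            # [1, seq_kv]
    mask = q_idx >= k_idx
    if window_size is not None:
        mask &= q_idx < k_idx + window_size
    return mask
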
@@ -316,10 +315,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"
@T.prim_func @T.prim_func
def flash_bwd_dsink( def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, dtype), # type: ignore dsinks: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=256) as (bx, by, bz):
sink = T.alloc_local([1], dtype) sink = T.alloc_local([1], dtype)
@@ -328,21 +327,18 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"
dsink_fragment = T.alloc_fragment([block], dtype) dsink_fragment = T.alloc_fragment([block], dtype)
sink[0] = Sinks[bx] sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block): for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
lse_fragment[i]) * delta_fragment[i] T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])
T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block])
return flash_bwd_dsink return flash_bwd_dsink
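
flash_bwd_dsink is the closed-form sink gradient: for each query row the sink holds softmax weight exp(sink)/exp(lse) but contributes no value, so its logit gradient is minus that weight times Delta (the row sum of O * dO). An eager-mode sketch of the same formula (sink_row_grad is a name I introduce here; lse is the base-2 log-sum-exp the forward kernel stores, and reducing the result over batch and sequence gives the per-head gradient):

import torch

LOG2_E = 1.44269504  # log2(e), matching the kernels' exp2 trick

def sink_row_grad(sinks, lse, delta):
    # sinks: [H]; lse, delta: [B, H, S] with delta = (O * dO).sum(-1).
    weight = torch.exp2(sinks.view(1, -1, 1) * LOG2_E - lse)  # softmax mass on the sink
    return -weight * delta                                    # per-row d(loss)/d(sink)
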
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, sinks, window_size, groups): def forward(ctx, q, k, v, sinks, window_size, groups):
def maybe_contiguous(x): def maybe_contiguous(x):
if x.stride(-1) != 1: if x.stride(-1) != 1:
return x.contiguous() return x.contiguous()
@@ -388,13 +384,14 @@ attention = _attention.apply
# Adapted and optimized from # Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor, def ref_program(
key: torch.Tensor, query: torch.Tensor,
value: torch.Tensor, key: torch.Tensor,
sinks: torch.Tensor, value: torch.Tensor,
sliding_window: Optional[int] = None, sinks: torch.Tensor,
dtype: torch.dtype = torch.float16) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape batch_size, num_keys, num_key_value_heads, head_dim = key.shape
@@ -430,32 +427,31 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
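
ref_program realizes the attention sink in plain PyTorch; behaviourally, the sink is one extra per-head logit that joins every row's softmax normalization and is then discarded before the value matmul. A standalone sketch of that behaviour (sink_softmax is my own name, and this is not the reference's exact code path):

import torch

def sink_softmax(scores, sinks):
    # scores: [B, H, Sq, Skv] already-scaled logits; sinks: [H].
    sink_col = sinks.view(1, -1, 1, 1).expand(scores.shape[0], -1, scores.shape[2], 1)
    logits = torch.cat([scores.float(), sink_col.float()], dim=-1)
    probs = torch.softmax(logits, dim=-1)
    return probs[..., :-1]  # drop the sink column; each row now sums to < 1
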
def main(BATCH: int = 1, def main(
H: int = 8, BATCH: int = 1,
N_CTX: int = 512, H: int = 8,
D_HEAD: int = 64, N_CTX: int = 512,
groups: int = 2, D_HEAD: int = 64,
window_size: Optional[int] = None, groups: int = 2,
dtype: str = "float16"): window_size: Optional[int] = None,
dtype: str = "float16",
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= N_CTX assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min( flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
K = torch.randn( K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
V = torch.randn_like(K).requires_grad_() V = torch.randn_like(K).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device="cuda").requires_grad_()
dO = torch.randn_like(Q) dO = torch.randn_like(Q)
@@ -479,16 +475,11 @@ def main(BATCH: int = 1,
"float16": (1e-2, 1e-2), "float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2), "bfloat16": (2e-2, 2e-2),
}[dtype] }[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose( assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
assert torch.allclose( assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}"
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅") print("All checks passed for tilelang kernels.✅")
@@ -509,17 +500,12 @@ def main(BATCH: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size') parser.add_argument("--batch", type=int, default=1, help="Batch size")
parser.add_argument('--h', type=int, default=64, help='Number of heads') parser.add_argument("--h", type=int, default=64, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
parser.add_argument('--d_head', type=int, default=128, help='Head dimension') parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
parser.add_argument('--groups', type=int, default=8, help='Groups') parser.add_argument("--groups", type=int, default=8, help="Groups")
parser.add_argument( parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
'--window_size', parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype)
@@ -23,9 +23,11 @@ def get_configs():
rep=100, rep=100,
) )
@tilelang.jit( @tilelang.jit(
out_idx=[3], pass_configs={ out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn( def flashattn(
batch, batch,
heads, heads,
@@ -41,12 +43,11 @@ def flashattn(
threads=256, threads=256,
dtype: str = "float16", dtype: str = "float16",
): ):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
@@ -68,13 +69,12 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j k_idx = k * block_N + j
if window_size is not None: if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -89,18 +89,18 @@ def flashattn(
by: T.int32, by: T.int32,
bz: T.int32, bz: T.int32,
): ):
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
@@ -112,8 +112,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention. # NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if window_size is not None: if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
@@ -128,19 +127,19 @@ def flashattn(
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.prim_func @T.prim_func
def main( def main(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype), Sinks: T.Tensor([heads], dtype),
): ):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -157,58 +156,58 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype) sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({ T.annotate_layout(
Q_shared: make_swizzled_layout(Q_shared), {
K_shared: make_swizzled_layout(K_shared), Q_shared: make_swizzled_layout(Q_shared),
V_shared: make_swizzled_layout(V_shared), K_shared: make_swizzled_layout(K_shared),
O_shared: make_swizzled_layout(O_shared), V_shared: make_swizzled_layout(V_shared),
}) O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min( end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) // start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
block_N) if window_size is not None else 0
for k in T.Pipelined( for k in T.Pipelined(
start, start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main return main
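
The Softmax and Rescale macros above are one step of a streaming (online) softmax carried out in the base-2 domain, so exp(x) becomes exp2(x * log2(e)) and the fast-math exp2 path enabled by TL_ENABLE_FAST_MATH can be used. A rough eager-mode sketch of a single KV-tile update, with scale = sm_scale * 1.44269504 as in the kernel (online_softmax_step is an illustrative name, not part of the diff):

import torch

def online_softmax_step(block_scores, m_prev, l_prev, acc_prev, v_block, scale):
    # block_scores: [M, N] raw Q·K^T for one tile; m_prev, l_prev: [M]; acc_prev: [M, D]; v_block: [N, D].
    m_new = torch.maximum(m_prev, block_scores.max(dim=-1).values)
    alpha = torch.exp2((m_prev - m_new) * scale)                         # rescale factor for the old state
    p = torch.exp2(block_scores * scale - (m_new * scale).unsqueeze(-1))
    l_new = l_prev * alpha + p.sum(dim=-1)
    acc_new = acc_prev * alpha.unsqueeze(-1) + p.to(v_block.dtype) @ v_block
    return m_new, l_new, acc_new

After the last tile the kernel folds the sink term into logsum and divides acc_o by it, which is the final normalization of this recurrence.
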
# Following functions are adapted and optimized from # Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor, def ref_program(
key: torch.Tensor, query: torch.Tensor,
value: torch.Tensor, key: torch.Tensor,
sinks: torch.Tensor, value: torch.Tensor,
sliding_window: Optional[int] = None, sinks: torch.Tensor,
dtype: torch.dtype = torch.float16) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape batch_size, num_keys, num_key_value_heads, head_dim = key.shape
@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
def gen_inputs( def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
B, query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
H, key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
Sq, value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
Skv, sinks = torch.randn([H], dtype=dtype, device="cuda")
D,
groups,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
return query, key, value, sinks return query, key, value, sinks
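
gen_inputs allocates H // groups KV heads for H query heads (grouped-query attention); the kernels resolve the grouping with the by // groups and bx // groups indexing seen above. An equivalent eager-mode view, if one wanted to materialize it, is to repeat each KV head across its group (expand_kv_for_gqa is a hypothetical helper, not used by this benchmark):

import torch

def expand_kv_for_gqa(key, value, groups):
    # key/value: [B, H // groups, S, D] -> [B, H, S, D]; KV head h serves query
    # heads h*groups .. h*groups + groups - 1, matching the head // groups indexing.
    return key.repeat_interleave(groups, dim=1), value.repeat_interleave(groups, dim=1)
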
@@ -277,12 +268,11 @@ def main(
): ):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= seq_q assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min( flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5 flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul total_flops = 2 * flops_per_matmul
@@ -310,15 +300,14 @@ def main(
block_N=block_N, block_N=block_N,
num_stages=num_stages, num_stages=num_stages,
threads=threads, threads=threads,
dtype=dtype) dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
torch.testing.assert_close( torch.testing.assert_close(
kernel(Q, K, V, sinks), kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), )
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅") print("All checks passed.✅")
# Benchmark tilelang # Benchmark tilelang
@@ -329,20 +318,14 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=64, help='heads') parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query') parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value') parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--groups', type=int, default=8, help='groups') parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument( parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
'--window_size', parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
type=int, parser.add_argument("--tune", action="store_true", help="tune configs")
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
args.dtype, args.tune)
@@ -20,27 +20,29 @@ def get_bwd_configs():
@tilelang.jit( @tilelang.jit(
out_idx=[3, 4], pass_configs={ out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_fwd( def flashattn_fwd(
batch, batch,
heads, heads,
seq_len, seq_len,
dim, dim,
window_size=None, # None for full attention, window_size=None, # None for full attention,
sm_scale=None, sm_scale=None,
block_M=64, block_M=64,
block_N=64, block_N=64,
num_stages=1, num_stages=1,
threads=128, threads=128,
dtype: str = "float16"): dtype: str = "float16",
):
if window_size is not None: if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N" assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
@@ -48,12 +50,12 @@ def flashattn_fwd(
@T.prim_func @T.prim_func
def flash_fwd( def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Sinks: T.Tensor([heads], dtype), # type: ignore Sinks: T.Tensor([heads], dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -70,7 +72,7 @@ def flashattn_fwd(
sinks = T.alloc_fragment([heads], dtype) sinks = T.alloc_fragment([heads], dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
@@ -78,22 +80,20 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.max(0, start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
k_idx = k * block_N + j k_idx = k * block_N + j
if window_size is not None: if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
0, -T.infinity(acc_s.dtype))
else: else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
@@ -103,8 +103,7 @@ def flashattn_fwd(
# NOTE(wt): check_inf is necessary for sliding window attention. # NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
if window_size is not None: if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
@@ -121,22 +120,23 @@ def flashattn_fwd(
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 - logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) T.copy(acc_o, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd return flash_fwd
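
flash_fwd writes lse in the base-2 domain: after the loop, logsum holds the full softmax denominator (sink included) relative to the running max, and log2(logsum) + scores_max * scale converts it to log2 of the absolute denominator. The backward pass can then rebuild every probability with a single exp2, roughly as follows (recompute_probs is an illustrative name; qk is laid out [keys, queries], matching the backward kernel's qkT):

import torch

def recompute_probs(qk, lse_cols, scale):
    # qk: [M, N] raw K·Q^T scores for one tile; lse_cols: [N], one log2-denominator per query column.
    return torch.exp2(qk * scale - lse_cols.unsqueeze(0))
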
@tilelang.jit( @tilelang.jit(
out_idx=[2], pass_configs={ out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
@@ -144,9 +144,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
@T.prim_func @T.prim_func
def flash_bwd_prep( def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype) o = T.alloc_fragment([blk, blk], dtype)
@@ -155,26 +155,27 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16")
delta = T.alloc_fragment([blk], accum_dtype) delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc) T.clear(acc)
for k in range(T.ceildiv(dim, blk)): for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], o) T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, bx, by * blk:(by + 1) * blk, k * blk:(k + 1) * blk], do) T.copy(dO[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk): for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep return flash_bwd_prep
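
flash_bwd_prep computes Delta, the per-query row sum of O ∘ dO, in blk × blk tiles with a float32 accumulator. The whole-tensor equivalent is a one-liner, spelled out here only to make the quantity explicit (compute_delta is my own name for it):

import torch

def compute_delta(O, dO):
    # O, dO: [B, H, S, D] -> Delta: [B, H, S]; accumulate in float32 like the kernel.
    return (O.float() * dO.float()).sum(dim=-1)
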
def make_dq_layout(dQ): def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit( @tilelang.jit(
out_idx=[1], pass_configs={ out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}) },
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float" accum_dtype = "float"
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
@@ -182,22 +183,24 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
@T.prim_func @T.prim_func
def flash_bwd_post( def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore dQ_out: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)}) T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy( T.copy(
dQ[bz, by, bx * blk:(bx + 1) * blk, :], dQ[bz, by, bx * blk : (bx + 1) * blk, :],
dQ_out[bz, by, bx * blk:(bx + 1) * blk, :], dQ_out[bz, by, bx * blk : (bx + 1) * blk, :],
) )
return flash_bwd_post return flash_bwd_post
@tilelang.jit(pass_configs={ @tilelang.jit(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, pass_configs={
}) tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd( def flashattn_bwd(
batch, batch,
heads, heads,
@@ -207,11 +210,10 @@ def flashattn_bwd(
sm_scale=None, sm_scale=None,
dtype: str = "float16", dtype: str = "float16",
): ):
block_M, block_N, num_stages, threads = get_bwd_configs() block_M, block_N, num_stages, threads = get_bwd_configs()
if sm_scale is None: if sm_scale is None:
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e) scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
@@ -222,15 +224,15 @@ def flashattn_bwd(
@T.prim_func @T.prim_func
def flash_bwd( def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore dV: T.Tensor(shape, dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
@@ -254,43 +256,46 @@ def flashattn_bwd(
dv_shared = T.alloc_shared([block_M, dim], dtype) dv_shared = T.alloc_shared([block_M, dim], dtype)
dk_shared = T.alloc_shared([block_M, dim], dtype) dk_shared = T.alloc_shared([block_M, dim], dtype)
T.annotate_layout({ T.annotate_layout(
dQ: make_dq_layout(dQ), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
}) dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
T.copy(K[bz, bx, by * block_M:(by + 1) * block_M, :], K_shared) }
T.copy(V[bz, bx, by * block_M:(by + 1) * block_M, :], V_shared) )
T.copy(K[bz, bx, by * block_M : (by + 1) * block_M, :], K_shared)
T.copy(V[bz, bx, by * block_M : (by + 1) * block_M, :], V_shared)
T.clear(dv) T.clear(dv)
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.min( loop_ed = (
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv( T.min(T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(seq_len, block_N))
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N) if window_size is not None
else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N : (k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
if window_size is not None: if window_size is not None:
qkT[i, j] = T.if_then_else( qkT[i, j] = T.if_then_else(
by * block_M + i <= k * block_N + j and by * block_M + i <= k * block_N + j and by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0
by * block_M + i > k * block_N + j - window_size, qkT[i, j], 0) )
else: else:
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
0) T.copy(dO[bz, bx, k * block_N : (k + 1) * block_N, :], dst=do)
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], dst=do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow) T.gemm(qkT_cast, B=do, C=dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
@@ -299,12 +304,12 @@ def flashattn_bwd(
T.copy(dsT_cast, dsT_shared) T.copy(dsT_cast, dsT_shared)
T.clear(dq) T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True) T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.atomic_add(dQ[bz, bx, k * block_N:(k + 1) * block_N, :], dq) T.atomic_add(dQ[bz, bx, k * block_N : (k + 1) * block_N, :], dq)
T.copy(dv, dv_shared) T.copy(dv, dv_shared)
T.copy(dk, dk_shared) T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dv_shared, dV[bz, bx, by * block_M : (by + 1) * block_M, :])
T.copy(dk_shared, dK[bz, bx, by * block_M:(by + 1) * block_M, :]) T.copy(dk_shared, dK[bz, bx, by * block_M : (by + 1) * block_M, :])
return flash_bwd return flash_bwd
@@ -316,10 +321,10 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"
@T.prim_func @T.prim_func
def flash_bwd_dsink( def flash_bwd_dsink(
Sinks: T.Tensor([heads], dtype), # type: ignore Sinks: T.Tensor([heads], dtype), # type: ignore
Delta: T.Tensor(shape, accum_dtype), # type: ignore Delta: T.Tensor(shape, accum_dtype), # type: ignore
lse: T.Tensor(shape, accum_dtype), # type: ignore lse: T.Tensor(shape, accum_dtype), # type: ignore
dsinks: T.Tensor(shape, accum_dtype), # type: ignore dsinks: T.Tensor(shape, accum_dtype), # type: ignore
): ):
with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block), batch, threads=128) as (bx, by, bz):
sink = T.alloc_local([1], dtype) sink = T.alloc_local([1], dtype)
@@ -328,18 +333,16 @@ def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"
dsink_fragment = T.alloc_fragment([block], accum_dtype) dsink_fragment = T.alloc_fragment([block], accum_dtype)
sink[0] = Sinks[bx] sink[0] = Sinks[bx]
T.copy(lse[bz, bx, by * block:(by + 1) * block], lse_fragment) T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment)
T.copy(Delta[bz, bx, by * block:(by + 1) * block], delta_fragment) T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment)
for i in T.Parallel(block): for i in T.Parallel(block):
dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - dsink_fragment[i] = -T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) * delta_fragment[i]
lse_fragment[i]) * delta_fragment[i] T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block])
T.copy(dsink_fragment, dsinks[bz, bx, by * block:(by + 1) * block])
return flash_bwd_dsink return flash_bwd_dsink
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, sinks, window_size): def forward(ctx, q, k, v, sinks, window_size):
BATCH, H, N_CTX, D_HEAD = q.shape BATCH, H, N_CTX, D_HEAD = q.shape
@@ -383,15 +386,15 @@ attention = _attention.apply
# Adapted and optimized from # Adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py # https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor, def ref_program(
key: torch.Tensor, query: torch.Tensor,
value: torch.Tensor, key: torch.Tensor,
sinks: torch.Tensor, value: torch.Tensor,
sliding_window: Optional[int] = None, sinks: torch.Tensor,
dtype: torch.dtype = torch.float16) -> torch.Tensor: sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
query = query.transpose(1, 2).contiguous().unsqueeze( ) -> torch.Tensor:
3) # align with the original function's interface query = query.transpose(1, 2).contiguous().unsqueeze(3) # align with the original function's interface
key = key.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous()
@@ -426,29 +429,22 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float()) output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
head_dim).to(dtype)
return output.transpose(1, 2).contiguous() return output.transpose(1, 2).contiguous()
def main(BATCH: int = 1, def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"):
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None: if window_size is not None:
print('Using sliding window attention.') print("Using sliding window attention.")
assert window_size <= N_CTX assert window_size <= N_CTX
flops_per_matmul = 2.0 * BATCH * H * min( flops_per_matmul = 2.0 * BATCH * H * min(window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
window_size, N_CTX // 2) * N_CTX * D_HEAD # just a rough estimation
else: else:
print('Using full attention.') print("Using full attention.")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5 flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD * 0.5
total_flops = 5 * flops_per_matmul total_flops = 5 * flops_per_matmul
Q = (torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()) Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda").requires_grad_()
K = torch.randn_like(Q).requires_grad_() K = torch.randn_like(Q).requires_grad_()
V = torch.randn_like(Q).requires_grad_() V = torch.randn_like(Q).requires_grad_()
sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_() sinks = torch.randn(H, dtype=torch_dtype, device=Q.device).requires_grad_()
@@ -473,16 +469,11 @@ def main(BATCH: int = 1,
"float16": (1e-2, 1e-2), "float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2), "bfloat16": (2e-2, 2e-2),
}[dtype] }[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f'O max err: {(O-O_ref).abs().max()}' assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose( assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
dV, dV_ref, rtol=rtol, atol=atol), f'dV max err: {(dV-dV_ref).abs().max()}' assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}"
assert torch.allclose( assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dq max err: {(dQ - dQ_ref).abs().max()}"
dK, dK_ref, rtol=rtol, atol=atol), f'dK max err: {(dK-dK_ref).abs().max()}' assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}"
assert torch.allclose(
dQ, dQ_ref, rtol=rtol, atol=atol), f'dq max err: {(dQ-dQ_ref).abs().max()}'
assert torch.allclose(
dsinks, dsinks_ref, rtol=rtol,
atol=atol), f'dsinks max err: {(dsinks-dsinks_ref).abs().max()}'
print("All checks passed for tilelang kernels.✅") print("All checks passed for tilelang kernels.✅")
@@ -503,16 +494,11 @@ def main(BATCH: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='Batch size') parser.add_argument("--batch", type=int, default=1, help="Batch size")
parser.add_argument('--h', type=int, default=64, help='Number of heads') parser.add_argument("--h", type=int, default=64, help="Number of heads")
parser.add_argument('--n_ctx', type=int, default=4096, help='Context size') parser.add_argument("--n_ctx", type=int, default=4096, help="Context size")
parser.add_argument('--d_head', type=int, default=128, help='Head dimension') parser.add_argument("--d_head", type=int, default=128, help="Head dimension")
parser.add_argument( parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
'--window_size', parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype) main(args.batch, args.h, args.n_ctx, args.d_head, args.window_size, args.dtype)