Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
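For context, the updated pre-commit configuration in the diff below replaces the yapf hooks with `ruff format` plus `ruff check --fix --exit-non-zero-on-fix`. The following is a minimal, illustrative Python sketch (not part of this commit) of running those same two passes outside pre-commit; it assumes `ruff` is installed and on PATH at the rev pinned in requirements-lint.txt (v0.14.7).

# Hypothetical helper, not part of this commit: invoke the same formatting and
# lint-fix passes that the new ruff pre-commit hooks perform.
import subprocess
import sys


def run_ruff(paths=(".",)):
    # "ruff format" takes over the role of the removed yapf hook; "ruff check --fix"
    # applies autofixable lint rules and, like the hook, exits non-zero if it fixed anything.
    fmt = subprocess.run(["ruff", "format", *paths])
    chk = subprocess.run(["ruff", "check", "--fix", "--exit-non-zero-on-fix", *paths])
    return fmt.returncode or chk.returncode


if __name__ == "__main__":
    sys.exit(run_ruff(sys.argv[1:] or (".",)))

Running pre-commit itself (`pre-commit run --all-files`) remains the canonical entry point; the sketch above is only meant to show what the hooks do.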
......@@ -39,19 +39,9 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.7 # sync with requirements-lint.txt
hooks:
- id: ruff-format
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/google/yapf
rev: v0.43.0 # sync with requirements-lint.txt
hooks:
- id: yapf
name: yapf-multiproc-bugfix
# yapf is not multiprocess safe, so we run a dummy yapf first.
args: [--in-place, docs/conf.py]
always_run: true
pass_filenames: false
- id: yapf
args: [--recursive, --in-place]
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1 # sync with requirements-lint.txt
hooks:
......@@ -62,4 +52,4 @@ repos:
^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
^.+\.svg$|
^.*\brequirements\b.*\.txt$
)
\ No newline at end of file
)
......@@ -7,10 +7,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -28,15 +25,15 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
import flash_attn
......
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -39,7 +36,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N = 64
num_stages = 2
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
......@@ -48,7 +45,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -60,11 +56,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -79,18 +74,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -116,22 +111,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -146,7 +140,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -155,20 +149,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
if block_mask[k]:
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
......@@ -177,26 +170,23 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
program = blocksparse_flashattn(
BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
kernel = tilelang.compile(program, out_idx=4)
def benchmark_fn():
......
......@@ -10,10 +10,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -31,39 +28,37 @@ def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=F
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
def benchmark_fn():
# Compute reference
# Expand block mask to full attention matrix
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal
# PyTorch reference implementation
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float('-inf'))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
attn = F.softmax(attn, dim=-1)
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
ref_output = torch.einsum("bhst,bhtd->bhsd", attn, v)
return ref_output
ref_latency = do_bench(
......
......@@ -15,10 +15,7 @@ def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
bsz, num_head, downsample_len, _ = x.shape
# N_CTX = downsample_len * BLOCK
sparse_index = torch.topk(x, topk, dim=-1).indices
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
False,
dtype=torch.bool,
device=x.device)
dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
dense_mask.scatter_(-1, sparse_index, True)
if use_dense_for_last_block:
dense_mask[:, :, -2:, :] = True
......@@ -56,7 +53,6 @@ def _fwd_kernel_inner(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
if mask_val == True:
......@@ -72,8 +68,7 @@ def _fwd_kernel_inner(
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0,
float('-inf'))
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
......@@ -153,7 +148,7 @@ def _fwd_kernel(
v_ptrs = V + off_v
mask_ptrs = block_mask_ptr + start_m * stride_bmm
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
......@@ -191,24 +186,12 @@ def _fwd_kernel(
acc = acc * l_recip
acc = acc.to(Out.dtype.element_ty)
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[
None, :] * stride_od
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX)
def _forward(ctx,
q,
k,
v,
block_sparse_mask,
sm_scale,
BLOCK_M=64,
BLOCK_N=64,
num_warps=None,
num_stages=1,
out=None):
def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
assert k.shape[2] == v.shape[2]
o = out if out is not None else torch.empty_like(q).contiguous()
......@@ -253,7 +236,6 @@ def _forward(ctx,
class _sparse_attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_sparse_dense, sm_scale):
# shape constraints
......@@ -271,24 +253,22 @@ block_sparse_triton_fn = _sparse_attention.apply
def benchmark_topk_sparse_attention():
from benchmark_configs import configs
torch.manual_seed(0)
# Config
for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs:
# Create inputs
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D_HEAD**0.5)
# Create sparse mask (downsampled to block level)
downsample_factor = BLOCK
downsample_len = math.ceil(SEQ_LEN / downsample_factor)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
device='cuda',
dtype=torch.bfloat16)
x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
......
......@@ -51,14 +51,15 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
decay = torch.exp(dt_segment_sum)
scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
causal_mask = torch.tril(
torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(
C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
out_prev = (
torch.einsum("bclhn,bchpn->bclhp", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out
)
out = out + out_prev
out = rearrange(out, "b c l h p -> b (c l) h p")
if D is not None:
......@@ -74,7 +75,6 @@ def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
@helion.kernel()
def helion_mamba2_chunk_scan_kernel(
cb: torch.Tensor,
......@@ -118,8 +118,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
dtype = cb.dtype
accum_dtype = torch.float32
assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
dtype)
assert x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype == dtype
out = torch.empty_like(x)
......@@ -127,11 +126,10 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
[nheads, chunk_size, headdim, batch, nchunks],
block_size=[1, block_m, block_n, 1, 1],
block_size=[1, block_m, block_n, 1, 1],
):
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_m].to(torch.float32)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_m].to(torch.float32)
scale_m_local = torch.exp2(dA_cumsum_local_m * p)
C_local = C[
......@@ -152,10 +150,8 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
tile_m,
tile_k,
]
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_k].to(torch.float32)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
dA_cumsum_local_k[None, :] * p)
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p - dA_cumsum_local_k[None, :] * p)
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local = (cb_local * dt_local[None, :]).to(dtype)
pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
......@@ -169,11 +165,9 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
acc_o = hl.dot(cb_local, x_local, acc=acc_o)
D_local = D[tile_h.begin].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n].to(torch.float32)
acc_o += x_residual * D_local
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n] = acc_o.to(dtype=dtype)
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin, tile_n] = acc_o.to(dtype=dtype)
return out
......@@ -182,12 +176,7 @@ def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
def get_configs():
iter_params = dict(
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
iter_params = dict(block_M=[64, 128, 256], block_N=[32, 64], block_K=[64, 128, 256], block_Dstate=[128], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
......@@ -198,19 +187,21 @@ def get_configs():
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def chunk_scan_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
def chunk_scan_fwd(
batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128,
):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
......@@ -218,20 +209,20 @@ def chunk_scan_fwd(batch,
@T.prim_func
def main(
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
with T.Kernel(nheads, T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), batch * nchunks, threads=threads) as (
bz,
bx,
by,
):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
......@@ -257,27 +248,32 @@ def chunk_scan_fwd(batch,
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.annotate_layout(
{
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared),
}
)
T.no_set_max_nreg()
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M : (m_idx + 1) * block_M], dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
C[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz // (nheads // ngroups),
0:block_Dstate,
],
C_shared,
)
T.copy(prev_states[batch_idx, chunk_idx, bz, n_idx * block_N : (n_idx + 1) * block_N, 0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
......@@ -286,34 +282,47 @@ def chunk_scan_fwd(batch,
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
cb[
batch_idx,
chunk_idx,
bz // (nheads // ngroups),
m_idx * block_M : (m_idx + 1) * block_M,
k * block_K : (k + 1) * block_K,
],
cb_shared,
)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i,
j] = cb_local[i,
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K : (k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
x[
batch_idx,
chunk_idx * chunk_size + k * block_K : chunk_idx * chunk_size + (k + 1) * block_K,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_shared,
)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
x[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
x_residual_shared,
)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
......@@ -321,24 +330,37 @@ def chunk_scan_fwd(batch,
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
Output[
batch_idx,
chunk_idx * chunk_size + m_idx * block_M : chunk_idx * chunk_size + (m_idx + 1) * block_M,
bz,
n_idx * block_N : (n_idx + 1) * block_N,
],
)
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=80, help='heads')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--dstate', type=int, default=128, help='dstate')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=80, help="heads")
parser.add_argument("--groups", type=int, default=1, help="groups")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--chunk_size", type=int, default=256, help="chunk size")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--dstate", type=int, default=128, help="dstate")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
batch, heads, groups, seq_len, chunk_size, dim, dstate = (
args.batch,
args.heads,
args.groups,
args.seq_len,
args.chunk_size,
args.dim,
args.dstate,
)
nchunks = math.ceil(seq_len / chunk_size)
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
......@@ -360,8 +382,7 @@ if __name__ == "__main__":
D = torch.randn(heads).half().cuda()
print("Benchmarking Triton...")
triton_latency = do_bench(
lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
triton_latency = do_bench(lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
print("Benchmarking Helion...")
......
......@@ -6,6 +6,7 @@ import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
......@@ -101,9 +102,7 @@ def get_configs(args, kwargs):
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -112,7 +111,9 @@ def get_configs(args, kwargs):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def matmul(
M,
N,
......@@ -159,9 +160,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -176,7 +177,6 @@ def matmul(
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......
......@@ -6,7 +6,8 @@ import tilelang as tl
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.autotuner import autotune
import itertools
......@@ -103,12 +104,11 @@ def tl_matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -116,10 +116,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10, enable=enable_rasteration)
......@@ -127,7 +129,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -137,7 +138,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
......@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
for config in configs:
print(config)
else:
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
......@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -247,7 +244,9 @@ def get_configs(args, kwargs):
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
@tl.jit(
out_idx=[2],
)
def matmul(
M,
N,
......@@ -291,13 +290,8 @@ if __name__ == "__main__":
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
type=bool,
default=False,
help="Whether to use roller to deduce search spaces")
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
......
......@@ -70,7 +70,8 @@ def get_configs(M, N, K):
thread_num,
policy,
enable_rasterization,
))
)
)
configs = [
{
......@@ -81,7 +82,8 @@ def get_configs(M, N, K):
"thread_num": c[4],
"policy": c[5],
"enable_rasterization": c[6], # keep param name for backward-compat
} for c in _configs
}
for c in _configs
]
return configs
......@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def kernel(
block_M=None,
block_N=None,
......@@ -165,10 +169,10 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype),
A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -182,9 +186,7 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......@@ -201,12 +203,12 @@ def matmul_sp(M, N, K, in_dtype, accum_dtype):
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared:
make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
}
)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
......@@ -241,18 +243,13 @@ if __name__ == "__main__":
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
choices=['cutlass', 'cusparselt'],
choices=["cutlass", "cusparselt"],
default=None,
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported",
)
args = parser.parse_args()
......@@ -274,7 +271,8 @@ if __name__ == "__main__":
if args.bench_torch_sparse is not None:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
if args.bench_torch_sparse == 'cutlass':
if args.bench_torch_sparse == "cutlass":
SparseSemiStructuredTensor._FORCE_CUTLASS = True
A_sp = to_sparse_semi_structured(A, transposed=False)
torch_sparse_latency = do_bench(lambda: A_sp @ B)
......@@ -285,8 +283,6 @@ if __name__ == "__main__":
print(f"Best config: {best_config}")
if args.bench_torch_sparse is not None:
print(
f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
)
print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
......@@ -104,9 +104,7 @@ def get_configs(args, kwargs):
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -116,7 +114,9 @@ def get_configs(args, kwargs):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def matmul(
M,
N,
......@@ -164,9 +164,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -181,7 +181,6 @@ def matmul(
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......
......@@ -20,33 +20,27 @@ extensions = [
"autoapi.extension",
]
autoapi_type = 'python'
autoapi_dirs = ['../tilelang']
autoapi_type = "python"
autoapi_dirs = ["../tilelang"]
autoapi_options = [
'members',
'undoc-members',
'show-inheritance',
'show-module-summary',
'special-members',
"members",
"undoc-members",
"show-inheritance",
"show-module-summary",
"special-members",
]
autoapi_keep_files = False # Useful for debugging the generated rst files
autoapi_generate_api_docs = True
autodoc_typehints = 'description'
autodoc_typehints = "description"
autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
source_suffix = {".rst": "restructuredtext", ".md": "markdown"}
myst_enable_extensions = [
"colon_fence",
"deflist",
]
myst_enable_extensions = ["colon_fence", "deflist"]
redirects = {"get_started/try_out": "../index.html#getting-started"}
......@@ -66,10 +60,7 @@ html_css_files = ["custom.css"]
footer_copyright = "© 2025-2026 TileLang"
footer_note = " "
html_theme_options = {
"light_logo": "img/logo-v2.png",
"dark_logo": "img/logo-v2.png",
}
html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"}
header_links = [
("Home", "https://github.com/tile-ai/tilelang"),
......
......@@ -11,22 +11,20 @@ import time
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float()
return output, lse
......@@ -45,23 +43,23 @@ def get_fwd_configs():
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
......@@ -85,7 +83,7 @@ def fast_flashattn(
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
......@@ -97,11 +95,11 @@ def fast_flashattn(
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
LSE: T.Tensor([batch, heads, seq_len], accum_dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -135,33 +133,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = (
T.ceildiv(q_block_offset +
block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
......@@ -216,8 +202,7 @@ def fast_flashattn(
for i in T.Parallel(block_M):
if q_block_offset + i < seq_len:
lse_val = T.if_then_else(l_i[i] > 0,
T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
lse_val = T.if_then_else(l_i[i] > 0, T.log(l_i[i]) + m_i[i], -T.infinity(accum_dtype))
LSE[bz, by, q_block_offset + i] = lse_val
bx_loop_var = current_bx + num_split_q
......@@ -234,16 +219,17 @@ def get_bwd_configs():
panel_size = [7, 8, 9, 10]
configs = []
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads,
enable_rasterization, panel_size):
configs.append({
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
})
for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size):
configs.append(
{
"block_M": m,
"block_N": n,
"num_stages": stages,
"threads": t,
"enable_rasterization": r,
"panel_size": p,
}
)
return configs
......@@ -256,8 +242,7 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
blk = 32
@T.prim_func
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
with T.Kernel(batch, heads, T.ceildiv(seq_len, blk)) as (bz, bx, by):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
......@@ -265,21 +250,33 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
@tilelang.autotune(configs=get_bwd_configs(), cache_input_tensors=True)
@tilelang.jit
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, block_N: int,
num_stages: int, threads: int, enable_rasterization: bool, panel_size: int):
sm_scale = (1.0 / dim)**0.5
def flashattn_bwd(
batch,
heads,
seq_len,
dim,
is_causal,
groups,
block_M: int,
block_N: int,
num_stages: int,
threads: int,
enable_rasterization: bool,
panel_size: int,
):
sm_scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
......@@ -287,14 +284,17 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
accum_dtype = "float"
@T.prim_func
def flash_bwd_kernel(Q: T.Tensor(q_shape,
dtype), K: T.Tensor(kv_shape,
dtype), V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype), lse: T.Tensor([batch, heads, seq_len],
accum_dtype),
Delta: T.Tensor([batch, heads, seq_len],
accum_dtype), dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype), dV: T.Tensor(kv_shape, accum_dtype)):
def flash_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
dO: T.Tensor(q_shape, dtype),
lse: T.Tensor([batch, heads, seq_len], accum_dtype),
Delta: T.Tensor([batch, heads, seq_len], accum_dtype),
dQ: T.Tensor(q_shape, accum_dtype),
dK: T.Tensor(kv_shape, accum_dtype),
dV: T.Tensor(kv_shape, accum_dtype),
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -315,8 +315,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx // groups, :], V_shared)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
......@@ -324,22 +324,21 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q_shared)
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q_shared)
T.clear(qkT)
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.exp(qkT[i, j] * sm_scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j,
P_acc[i, j], 0.0)
P_acc[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, P_acc[i, j], 0.0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do_shared)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do_shared)
T.clear(dP)
T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -347,7 +346,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, groups, block_M: int, b
T.copy(P_acc, p_cast)
T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta_shared)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared)
for i, j in T.Parallel(block_M, block_N):
p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale
......@@ -378,8 +377,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
def flash_bwd_post(dQ_in: T.Tensor(shape, accum_dtype), dQ_out: T.Tensor(shape, dtype)):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.copy(
dQ_in[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_in[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
......@@ -446,22 +445,14 @@ def benchmark_function(func, *args, warmup=10, repeat=100):
return np.median(times)
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
device = "cuda"
dtype = torch.float16
torch.manual_seed(42)
torch.cuda.manual_seed(42)
print(
f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}"
)
print(f"Test configuration: batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, is_causal={is_causal}, groups={groups}")
flops_per_gemm = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 5 * flops_per_gemm
......@@ -517,22 +508,19 @@ def main(batch: int = 1,
o_ref.backward(dO)
print("Verifying backward pass correctness...")
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
if dq_close:
print("dQ is correct.")
else:
print("dQ mismatch detected.")
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
if dk_close:
print("dK is correct.")
else:
print("dK mismatch detected.")
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
if dv_close:
print("dV is correct.")
else:
......@@ -553,9 +541,7 @@ def main(batch: int = 1,
torch.cuda.synchronize()
ref_latency = benchmark_function(run_reference_fwd_bwd, warmup=10, repeat=100)
print(
f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops"
)
print(f"Reference PyTorch Forward+Backward: {ref_latency:.2f} ms | {total_flops / ref_latency * 1e-9:.2f} TFlops")
def run_complete_fwd_bwd():
o_tl_bench, lse_tl_bench = fwd_kernel(q, k, v)
......@@ -593,12 +579,12 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=1024, help='sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=1024, help="sequence length")
parser.add_argument("--dim", type=int, default=64, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
......@@ -13,10 +13,10 @@ def supply_tensors_gpu(params):
"""Supply function that creates tensors on GPU for ROCm/HIP."""
tensors = []
for param in params:
if hasattr(param, 'shape') and hasattr(param, 'dtype'):
if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device
shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
tensors.append(tensor)
else:
tensors.append(param)
......@@ -24,22 +24,20 @@ def supply_tensors_gpu(params):
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
......@@ -58,23 +56,23 @@ def get_configs():
valid_configs = []
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
threads, num_stages,
enable_rasterization, k_pack,
panel_size, qk_coalesced_width,
v_coalesced_width):
valid_configs.append({
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
})
for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(
block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width
):
valid_configs.append(
{
"block_M": m,
"block_N": n,
"num_split_q": s,
"threads": t,
"num_stages": stages,
"enable_rasterization": r,
"k_pack": k,
"panel_size": p,
"qk_coalesced_width": qkw,
"v_coalesced_width": vw,
}
)
return valid_configs
......@@ -98,7 +96,7 @@ def fast_flashattn(
qk_coalesced_width: int,
v_coalesced_width: int,
):
scale = (1.0 / dim)**0.5
scale = (1.0 / dim) ** 0.5
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
......@@ -110,10 +108,10 @@ def fast_flashattn(
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
T.use_swizzle(panel_size, enable=enable_rasterization)
......@@ -147,32 +145,21 @@ def fast_flashattn(
m_prev = T.alloc_fragment([block_M], accum_dtype)
scale_factor = T.alloc_fragment([block_M], accum_dtype)
T.copy(
Q[bz, q_block_offset:q_block_offset + block_M, by, :],
Q_shared,
coalesced_width=vec_size)
T.copy(Q[bz, q_block_offset : q_block_offset + block_M, by, :], Q_shared, coalesced_width=vec_size)
loop_end_k = T.ceildiv(q_block_offset + block_M,
block_N) if is_causal else T.ceildiv(seq_len, block_N)
loop_end_k = T.ceildiv(q_block_offset + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
row_sum = T.alloc_fragment([block_M], accum_dtype)
for k in T.Pipelined(loop_end_k, num_stages=num_stages):
kv_idx = k * block_N
T.copy(
K[bz, kv_idx:kv_idx + block_N, by // groups, :],
K_shared,
coalesced_width=vec_size)
T.copy(
V[bz, kv_idx:kv_idx + block_N, by // groups, :],
V_shared,
coalesced_width=v_vec_size)
T.copy(K[bz, kv_idx : kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
T.copy(V[bz, kv_idx : kv_idx + block_N, by // groups, :], V_shared, coalesced_width=v_vec_size)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
......@@ -222,13 +209,7 @@ def fast_flashattn(
return main
def main(batch: int = 1,
heads: int = 8,
seq_len: int = 4096,
dim: int = 128,
is_causal: bool = False,
groups: int = 1):
def main(batch: int = 1, heads: int = 8, seq_len: int = 4096, dim: int = 128, is_causal: bool = False, groups: int = 1):
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
......@@ -250,18 +231,16 @@ def main(batch: int = 1,
print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=100)
print(
f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
)
print(f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=8, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--groups', type=int, default=1, help='groups')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=8, help="heads")
parser.add_argument("--seq_len", type=int, default=4096, help="sequence length")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--is_causal", action="store_true", help="causal")
parser.add_argument("--groups", type=int, default=1, help="groups")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
......@@ -25,22 +25,7 @@ def check_hopper():
return False
def kernel(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
......@@ -50,13 +35,11 @@ def kernel(N,
@T.prim_func
def conv(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -65,11 +48,13 @@ def kernel(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
})
T.annotate_layout(
{
out_shared: make_swizzled_layout(out_shared),
data_shared: make_swizzled_layout(data_shared),
kernel_shared: make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -81,10 +66,8 @@ def kernel(N,
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i, j] = T.if_then_else(
in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
......
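As a hedged cross-check of the convolution index math in the conv hunk above, the OH/OW formula matches what PyTorch's conv2d produces for the same stride, padding, and dilation (names are local to this sketch; the layout here is NCHW, so it only checks the output-size formula, not the NHWC kernel's indexing):

import torch
import torch.nn.functional as F

N, C, H, W, F_out, K, S, D, P = 1, 3, 16, 16, 8, 3, 1, 1, 1
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

x = torch.randn(N, C, H, W)
w = torch.randn(F_out, C, K, K)
y = F.conv2d(x, w, stride=S, padding=P, dilation=D)
assert y.shape == (N, F_out, OH, OW)  # (1, 8, 16, 16)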
......@@ -20,9 +20,9 @@ def kernel(
@T.prim_func
def matmul(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......
......@@ -51,8 +51,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
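
The BANDWIDTH branch above restricts each query block to a key range [lo, hi); a small pure-Python sketch of that bound computation (values are illustrative, not from the diff):

def window_bounds(start_q, start_m, BLOCK_M, bandwidth=None):
    hi = start_q + (start_m + 1) * BLOCK_M
    lo = max(0, start_q + start_m * BLOCK_M - bandwidth) if bandwidth else 0
    return lo, hi

print(window_bounds(start_q=0, start_m=2, BLOCK_M=64, bandwidth=128))  # (0, 192)
print(window_bounds(start_q=0, start_m=4, BLOCK_M=64, bandwidth=128))  # (128, 320)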
......@@ -120,7 +119,8 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
start_q=seq_kv - seq_q,
)
return o
......@@ -137,12 +137,11 @@ def main(
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -170,15 +169,14 @@ def main(
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
triton_program(Q, K, V, sinks, window_size), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
......@@ -198,20 +196,14 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)
......@@ -50,8 +50,7 @@ def triton_kernel(
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M - BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
......@@ -117,26 +116,28 @@ def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.T
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
start_q=seq_kv - seq_q,
)
return o
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -163,15 +164,14 @@ def main(batch: int = 1,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
......@@ -184,19 +184,13 @@ def main(batch: int = 1,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument("--seq_q", type=int, default=4096, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -23,9 +23,11 @@ def get_configs():
rep=100,
)
@tilelang.jit(
out_idx=[3], pass_configs={
out_idx=[3],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(
batch,
heads,
......@@ -41,12 +43,11 @@ def flashattn(
threads=256,
dtype: str = "float16",
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
if sm_scale is None:
sm_scale = (1.0 / dim)**0.5
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // groups
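
The 1.44269504 factor folded into scale above is log2(e): multiplying the logits by it lets the kernel use exp2 instead of exp inside the softmax. A minimal sketch of the identity being relied on:

import torch

x = torch.randn(8)
log2_e = 1.44269504
# exp(x) == 2**(x * log2(e)), so exp2 can replace exp once the scale is folded in.
assert torch.allclose(torch.exp(x), torch.exp2(x * log2_e), atol=1e-6)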
......@@ -68,13 +69,12 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by // groups, k * block_N : (k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0, -T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -89,18 +89,18 @@ def flashattn(
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by // groups, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -112,8 +112,7 @@ def flashattn(
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
......@@ -128,19 +127,19 @@ def flashattn(
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
......@@ -157,58 +156,58 @@ def flashattn(
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.annotate_layout({
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
})
T.annotate_layout(
{
Q_shared: make_swizzled_layout(Q_shared),
K_shared: make_swizzled_layout(K_shared),
V_shared: make_swizzled_layout(V_shared),
O_shared: make_swizzled_layout(O_shared),
}
)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
end = T.min(T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
start = T.max(0, (bx * block_M + past_len - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]],
):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
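
The Softmax and Rescale macros above implement the streaming (online) softmax update: for each key block, the running row max, the running denominator, and the partial accumulator are rescaled before the new block's contribution is added. A minimal sketch of that recurrence, checked against a plain softmax (illustrative; it keeps raw probabilities rather than the P@V accumulator so the check is exact):

import torch

def online_softmax(score_blocks):
    m = torch.full((score_blocks[0].shape[0],), float("-inf"))  # running row max
    l = torch.zeros_like(m)                                     # running denominator
    acc = None                                                  # unnormalized probabilities so far
    for s in score_blocks:                                      # s: [M, N_block]
        m_new = torch.maximum(m, s.max(dim=-1).values)
        rescale = torch.exp(m - m_new)                          # shrink old contributions
        p = torch.exp(s - m_new[:, None])
        l = l * rescale + p.sum(dim=-1)
        acc = p if acc is None else torch.cat([acc * rescale[:, None], p], dim=-1)
        m = m_new
    return acc / l[:, None]

blocks = list(torch.randn(4, 12).split(4, dim=-1))
ref = torch.softmax(torch.cat(blocks, dim=-1), dim=-1)
assert torch.allclose(online_softmax(blocks), ref, atol=1e-6)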
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16) -> torch.Tensor:
def ref_program(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: Optional[int] = None,
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
......@@ -244,23 +243,15 @@ def ref_program(query: torch.Tensor,
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(dtype)
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups, head_dim).to(dtype)
return output.transpose(1, 2).contiguous()
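
One common reference formulation appends the sink as one extra logit column before the softmax, while the kernel above instead adds the sink's exponential into logsum; both only enlarge the softmax denominator. A hedged sketch of why the two are equivalent (the per-row sink logits here are hypothetical; in the diff the sinks are per-head):

import torch

scores = torch.randn(4, 16)   # one block of query rows, 16 keys
sink = torch.randn(4, 1)      # hypothetical per-row sink logit

# Route 1: append the sink column, softmax, then drop it.
with_sink = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[:, :-1]
# Route 2: only enlarge the denominator, as the logsum update does.
manual = torch.exp(scores) / (torch.exp(scores).sum(-1, keepdim=True) + torch.exp(sink))
assert torch.allclose(with_sink, manual, atol=1e-6)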
def gen_inputs(
B,
H,
Sq,
Skv,
D,
groups,
dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device='cuda')
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device='cuda')
sinks = torch.randn([H], dtype=dtype, device='cuda')
def gen_inputs(B, H, Sq, Skv, D, groups, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=dtype, device="cuda")
key = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
value = torch.randn([B, H // groups, Skv, D], dtype=dtype, device="cuda")
sinks = torch.randn([H], dtype=dtype, device="cuda")
return query, key, value, sinks
......@@ -277,12 +268,11 @@ def main(
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
print("Using sliding window attention.")
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
flops_per_matmul = 2.0 * batch * heads * min(window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
print("Using full attention.")
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
......@@ -310,15 +300,14 @@ def main(
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
dtype=dtype,
)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), rtol=1e-2, atol=1e-2
)
print("All checks passed.✅")
# Benchmark tilelang
......@@ -329,20 +318,14 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="heads")
parser.add_argument("--seq_q", type=int, default=2048, help="sequence length of query")
parser.add_argument("--seq_kv", type=int, default=2048, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size, args.dtype, args.tune)