"vscode:/vscode.git/clone" did not exist on "9392125edad18464f1f267eb99bc4036c38f7ed0"
Unverified commit d9a171ce authored by Tong WU, committed by GitHub

[Example] Add examples to support efficient attention sink forward process (#853)



* [Example] Add a new example to support attention sink for MHA

- Introduced a new example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Added a reference attention function to validate the implementation against PyTorch.
- Included argument parsing for command-line execution of the example.
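
For context, a minimal PyTorch sketch of the sink-augmented softmax that the reference function below implements (single head, illustrative shapes and names):

import torch

# The sink contributes only to the softmax normalizer; no value vector is
# attached to it, so attention mass can "leak" to the sink and rows sum to < 1.
logits = torch.randn(4, 16)   # [queries, keys], already scaled and masked
sink = torch.randn(())        # learned per-head sink logit

row_max = torch.maximum(logits.max(dim=-1, keepdim=True).values, sink)
scores = torch.exp(logits - row_max)
probs = scores / (scores.sum(dim=-1, keepdim=True) + torch.exp(sink - row_max))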

* [Example] Replace MHA sink forward example with updated implementation

- Removed the old example script for multi-head attention (MHA) with sliding window attention and sink tokens.
- Introduced a new example script that modifies the attention mechanism to enhance performance and maintainability.
- Updated argument parsing and reference functions to align with the new implementation.

* Enhance MHA sink example with sliding window support

- Added a `window_size` parameter to the `flashattn` function to enable sliding window attention.
- Added an assertion to ensure `window_size` is divisible by `block_N`.
- Updated the main function to include a `tune` option for performance tuning.
- Introduced a new test file to validate both full attention and sliding window scenarios.
- Adjusted FLOPS calculation to account for the sliding window configuration.
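
For reference, a usage sketch of the sliding-window path (module name taken from the new example files; the parameter values are only illustrative):

import example_mha_sink_fwd_bhsd as mha_sink

# window_size must be a multiple of block_N (128 in the default config)
# and no larger than seq_q; None falls back to full attention.
mha_sink.main(batch=8, heads=32, seq_q=4096, seq_kv=4096, dim=128,
              window_size=1024, tune=False)

# Equivalent command line:
#   python example_mha_sink_fwd_bhsd.py --window_size 1024
#   python example_mha_sink_fwd_bhsd.py --tune   # autotune block/stage/thread configs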

* lint

* [Fix] Add check_inf step to fix a bug in sliding window attention (SWA)
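
Concretely, the failure mode this check_inf step guards against (illustrative values; mirrors the scores_max handling in the kernels below):

import torch

# Under sliding-window masking an entire tile row can be fully masked, so its
# running max is -inf and the rescaling exp2(score - max) produces NaN:
scores = torch.full((4,), float("-inf"))
row_max = scores.max()
print(torch.exp2(scores - row_max))   # tensor([nan, nan, nan, nan])

# check_inf clamps a -inf row max to 0 first, so the masked row simply
# contributes exp2(-inf) = 0 instead of NaN:
row_max = torch.zeros(()) if torch.isinf(row_max) else row_max
print(torch.exp2(scores - row_max))   # tensor([0., 0., 0., 0.])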

* Migrate to BHSD layout ([batch, heads, seq, dim]) to align with the Triton baselines

* lint

* fix typo

* Refactor MHA sink example to take separate seq_q and seq_kv sequence-length parameters.
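
As an illustration of the separate lengths (toy numbers; past_len is the name used in the kernels below):

# With seq_q new queries and seq_kv >= seq_q total keys, each query position is
# shifted by past_len = seq_kv - seq_q, so under the causal mask query i can
# attend to keys 0 .. past_len + i.
seq_q, seq_kv = 4, 10
past_len = seq_kv - seq_q   # 6
for i in range(seq_q):
    print(f"query {i} attends to keys 0..{past_len + i}")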

* Add GQA sink example for optimized attention mechanism & lint fix

* fix several typos and bugs

* lint

* fix speed issues of swa

* Update examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent b4483090
# Modified from tilelang/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
import itertools
import argparse
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=500,
rep=100,
)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups=1,
window_size=None, # None for full attention
block_M=128,
block_N=128,
num_stages=2,
threads=256,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To keep the softmax well defined, set scores_max to 0 when it is -inf,
# i.e. when every score in the row has been masked out. This step is called
# check_inf in the FlashAttention3 code.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(
start[0],
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor:
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
query = query.transpose(1, 2).contiguous()
query = query.view(batch_size, query.shape[1], num_key_value_heads, -1, head_dim)
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16)
return output.transpose(1, 2).contiguous()
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def gen_inputs(B, H, Sq, Skv, D,
groups) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda')
key = torch.randn([B, H // groups, Skv, D], dtype=torch.float16, device='cuda')
value = torch.randn([B, H // groups, Skv, D], dtype=torch.float16, device='cuda')
sinks = torch.randn([H], dtype=torch.float16, device='cuda')
return query, key, value, sinks
def main(
batch: int = 1,
heads: int = 64,
seq_q: int = 4096,
seq_kv: int = 4096,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
tune: bool = False,
):
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
# Benchmark triton
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
# Benchmark tilelang
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.tune)
# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd.py
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
import itertools
import argparse
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
block_M=64,
block_N=64,
num_stages=1,
threads=128):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To keep the softmax well defined, set scores_max to 0 when it is -inf,
# i.e. when every score in the row has been masked out. This step is called
# check_inf in the FlashAttention3 code.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Modified from https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3) # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16)
return output.transpose(1, 2).contiguous()
def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda')
sinks = torch.zeros([H], dtype=torch.float16, device='cuda')
return query, key, value, sinks
def main(batch: int = 8,
heads: int = 32,
seq_q: int = 4096,
seq_kv: int = 4096,
dim: int = 128,
window_size: int | None = None,
tune: bool = False):
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
latency = do_bench(lambda: ref_program(Q, K, V, sinks, window_size), warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.tune)
# Modified from tilelang/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
# Optimized for Hopper architecture, with a benchmark to compare with official Triton impl
import torch
import tilelang
from tilelang.autotuner import autotune
from tilelang.profiler import do_bench
import tilelang.language as T
import itertools
import argparse
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[0, 1, 2], threads=[128, 256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=500, rep=100)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size=None, # None for full attention
block_M=128,
block_N=128,
num_stages=2,
threads=256):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
if window_size is not None:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx and q_idx < k_idx + window_size, 0,
-T.infinity(acc_s.dtype))
else:
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To keep the softmax well defined, set scores_max to 0 when it is -inf,
# i.e. when every score in the row has been masked out. This step is called
# check_inf in the FlashAttention3 code.
# NOTE(wt): check_inf is necessary for sliding window attention.
for i in T.Parallel(block_M):
if window_size is not None:
scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0,
scores_max[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
Sinks: T.Tensor([heads], dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
sinks = T.alloc_fragment([block_M], dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(
start[0],
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i in T.Parallel(block_M):
logsum[i] += T.exp2(sinks[i] * 1.44269504 -
scores_max[i] * scale) # The only change for attention sink
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
# Following functions are adapted and optimized from
# https://github.com/openai/gpt-oss/blob/main/gpt_oss/triton/attention.py
def ref_program(query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sinks: torch.Tensor,
sliding_window: int | None = None) -> torch.Tensor:
query = query.transpose(1, 2).contiguous().unsqueeze(
3)  # align with the original function's interface
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
start_q = num_keys - num_queries
sm_scale: float = 1.0 / head_dim**0.5
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
key = key.unsqueeze(3)
value = value.unsqueeze(3)
pos_keys = torch.arange(num_keys, device=query.device)
pos_queries = torch.arange(num_queries, device=query.device) + start_q
mask = pos_keys[None, :] > pos_queries[:, None]
mask = mask.float().masked_fill(mask, float("-inf"))
if sliding_window:
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
mask.masked_fill_(too_old, float("-inf"))
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
logits = logits + mask[None, None, None, :, :]
logits_max = torch.max(logits, dim=-1, keepdim=True).values
logits_or_sinks_max = torch.maximum(sinks, logits_max)
sinks = torch.exp(sinks - logits_or_sinks_max)
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups,
head_dim).to(torch.float16)
return output.transpose(1, 2).contiguous()
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
BLOCK_N = 64
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def gen_inputs(B, H, Sq, Skv, D) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn([B, H, Sq, D], dtype=torch.float16, device='cuda')
key = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda')
value = torch.randn([B, H, Skv, D], dtype=torch.float16, device='cuda')
sinks = torch.randn([H], dtype=torch.float16, device='cuda')
return query, key, value, sinks
def main(batch: int = 8,
heads: int = 32,
seq_q: int = 4096,
seq_kv: int = 4096,
dim: int = 128,
window_size: int | None = None,
tune: bool = False):
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim)
torch.testing.assert_close(
kernel(Q, K, V, sinks), ref_program(Q, K, V, sinks, window_size), rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.tune)
import tilelang.testing
import example_mha_sink_fwd_bhsd
import example_mha_sink_fwd_bhsd_wgmma_pipelined
import example_gqa_sink_fwd_bhsd_wgmma_pipelined
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_full_attn():
example_mha_sink_fwd_bhsd.main()
@tilelang.testing.requires_cuda
def test_example_mha_sink_fwd_bhsd_sliding_window():
example_mha_sink_fwd_bhsd.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
if __name__ == "__main__":
tilelang.testing.main()