Unverified Commit 8f001e02 authored by Tong WU, committed by GitHub

[BugFix] Phase out dependency on Triton in sink examples to make CI happy (#1045)



* [BugFix] Phase out dependency on Triton in sink examples to make CI happy

- Added `benchmark_gqa_sink_fwd.py` and `benchmark_mha_sink_fwd.py` to benchmark the GQA and MHA attention-sink forward passes against Triton implementations (a minimal driver sketch is shown after this list).
- Refactored the existing attention-sink examples to drop the Triton kernel definitions from the reference programs, streamlining the code.
- Updated the input-generation and benchmarking logic for better configurability and performance measurement.
- Improved the overall structure and organization of the examples for clarity and usability.
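A minimal driver for the new GQA benchmark, shown as a sketch: the keyword arguments mirror the script's `main()` signature and its argparse defaults, and the import path assumes the script's directory is importable (adjust to your checkout).

from benchmark_gqa_sink_fwd import main  # assumed import path

main(
    batch=1,
    heads=64,
    seq_q=2048,
    seq_kv=2048,
    dim=128,
    groups=8,
    window_size=None,   # None = full attention; e.g. 1024 enables sliding-window attention
    dtype="float16",    # or "bfloat16"
    tune=False,         # True autotunes the tilelang kernel configuration
)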

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8ce27782
benchmark_gqa_sink_fwd.py (new file)
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
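# The per-head sink acts as an extra logit that competes in the softmax but
# contributes no value vector; without a Sinks tensor a sink logit of 0 is used.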
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize running softmax statistics: row max m_i (seeded with the sink logit) and row sum l_i
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# softmax scale
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
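# K/V column range for this query tile: with a sliding window only the last
# BANDWIDTH keys before each query are visited; otherwise all keys up to the
# causal frontier of the tile.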
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
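# Mask out future (non-causal) positions and, when windowed, keys older than
# BANDWIDTH relative to each query position.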
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
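# GQA: query head off_h reads the shared K/V head off_h // groups.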
k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
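# Epilogue: the sink contributes exp(sink - m_i) to the softmax denominator
# only; the value accumulator acc is unchanged by it.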
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
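# One program per (BLOCK_M query rows, batch*head pair). The TensorDescriptors
# carry the tile shapes used by the block loads/stores inside the kernel, and
# start_q = seq_kv - seq_q aligns the causal mask when queries cover only the
# trailing positions of the K/V sequence.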
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def main(
batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
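# Two GEMMs per attention step (Q*K^T and P*V); the window / 0.5 factors above
# roughly account for the masked-out positions.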
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, groups, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
groups,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
# Benchmark triton
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=64, help='heads')
parser.add_argument('--seq_q', type=int, default=2048, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=2048, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.groups, args.window_size,
args.dtype, args.tune)
benchmark_mha_sink_fwd.py (new file)
import torch
import argparse
from tilelang.profiler import do_bench
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize running softmax statistics: row max m_i (seeded with the sink logit) and row sum l_i
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# softmax scale
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
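# MHA: every query head has its own K/V head, so K/V are indexed by off_h directly.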
k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
BLOCK_N = 64
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def main(batch: int = 1,
heads: int = 32,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
print('Using sliding window attention.')
assert window_size <= seq_q
flops_per_matmul = 2.0 * batch * heads * min(
window_size, seq_kv // 2) * seq_q * dim # just a rough estimation
else:
print('Using full attention.')
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim * 0.5
total_flops = 2 * flops_per_matmul
if tune:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, window_size, dtype=dtype)
print(f"Best latency: {kernel.latency}")
print(f"Best TFlops: {total_flops / kernel.latency * 1e-9}")
print(f"Best config: {kernel.config}")
else:
block_M = 128
block_N = 128
num_stages = 2
threads = 256
print(f"{block_M=}, {block_N=}, {num_stages=}, {threads=}")
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
window_size,
block_M=block_M,
block_N=block_N,
num_stages=num_stages,
threads=threads,
dtype=dtype)
Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype)
torch.testing.assert_close(
kernel(Q, K, V, sinks),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2)
print("All checks passed.✅")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_q', type=int, default=4096, help='sequence length of query')
parser.add_argument('--seq_kv', type=int, default=4096, help='sequence length of key/value')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument(
'--window_size',
type=int,
default=None,
help='window size (default: None, which means full attention)')
parser.add_argument(
'--dtype', type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument('--tune', action='store_true', help='tune')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype,
args.tune)
example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
@@ -9,9 +9,6 @@ import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Optional
@@ -255,122 +252,6 @@ def ref_program(query: torch.Tensor,
return output.transpose(1, 2).contiguous()
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
groups: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h // groups, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
BLOCK_N = 64
groups = n_heads // n_heads_kv
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
groups=groups,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def gen_inputs(
B,
H,
@@ -443,27 +324,11 @@ def main(
atol=1e-2)
print("All checks passed.✅")
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
# Benchmark triton
latency_triton = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency_triton))
print("Triton: {:.2f} TFlops".format(total_flops / latency_triton * 1e-9))
# Benchmark tilelang
latency_tilelang = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency_tilelang))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9))
print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
example_mha_sink_fwd_bhsd_wgmma_pipelined.py
@@ -9,9 +9,6 @@ import tilelang.language as T
from tilelang.layout import make_swizzled_layout
import itertools
import argparse
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Optional
@@ -249,119 +246,6 @@ def ref_program(query: torch.Tensor,
return output.transpose(1, 2).contiguous()
@triton.jit
def triton_kernel(
Q,
K,
V,
Sinks,
sm_scale,
Out,
Z,
H,
N_Q_CTX,
N_KV_CTX,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BANDWIDTH: tl.constexpr,
start_q: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
# load attention sinks
if Sinks is not None: # noqa: SIM108
sink = tl.load(Sinks + off_h).to(tl.float32)
else:
sink = 0
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
if BANDWIDTH:
lo, hi = tl.maximum(0, start_q + start_m * BLOCK_M -
BANDWIDTH), start_q + (start_m + 1) * BLOCK_M
else:
lo, hi = 0, start_q + (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]
if BANDWIDTH:
too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
mask = mask | too_old
k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
qk = tl.dot(q, k, allow_tf32=False)
qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
p = tl.math.exp(qk)
alpha = tl.math.exp(m_i - m_ij)
l_ij = tl.sum(p, 1)
acc = acc * alpha[:, None]
v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
# v = v.to(tl.float32)
p = p.to(v.dtype) # We perform fp16 gemm to utilize tensor core
acc = tl.dot(p, v, acc, allow_tf32=False)
l_i = l_i * alpha + l_ij
m_i = m_ij
sink = tl.math.exp(sink - m_i)
z = l_i + sink
acc = acc / z[:, None]
# m_i += tl.math.log(l_i)
# m_ptrs = M + off_hz * N_Q_CTX + offs_m
# tl.store(m_ptrs, m_i)
acc = acc.to(Out.dtype)[None, None, :, :]
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
BLOCK_N = 64
o = torch.empty_like(Q)
grid = (triton.cdiv(seq_q, BLOCK_M), bs * n_heads, 1)
triton_kernel[grid](
TensorDescriptor.from_tensor(Q, [1, 1, BLOCK_M, head_dim]),
TensorDescriptor.from_tensor(K, [1, 1, BLOCK_N, head_dim]),
TensorDescriptor.from_tensor(V, [1, 1, BLOCK_N, head_dim]),
Sinks,
1.0 / head_dim**0.5,
TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, head_dim]),
bs,
n_heads,
N_Q_CTX=seq_q,
N_KV_CTX=seq_kv,
HEAD_DIM=head_dim,
BANDWIDTH=window_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
start_q=seq_kv - seq_q)
return o
def gen_inputs(
B,
H,
@@ -429,18 +313,6 @@ def main(batch: int = 1,
atol=1e-2)
print("All checks passed.✅")
if torch.allclose(
triton_program(Q, K, V, sinks, window_size),
ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype),
rtol=1e-2,
atol=1e-2):
print("Checks for triton passed.✅")
else:
print("Checks for triton failed.❌")
latency = do_bench(lambda: triton_program(Q, K, V, sinks, window_size), warmup=500)
print("Triton: {:.2f} ms".format(latency))
print("Triton: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......