"vscode:/vscode.git/clone" did not exist on "429dc97501f64c7fbd51c8ada8b149dcb841ffef"
Commit cf6e11c9 authored by qisan's avatar qisan
Browse files

feat: merge dcu branch features

parents 3f27f85a d0436b7b
Pipeline #3369 failed with stages
in 0 seconds
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close
@tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m):
dtype = T.float
group_size = 128
fp8_min = -448.0
fp8_max = 448.0
@T.prim_func
def per_token_cast(
X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx
row_g_id = by
y_local = T.alloc_fragment((blk_m, group_size), dtype)
y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
T.annotate_layout(
{
y_local: T.Fragment(y_local.shape, forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
}
)
T.copy(X[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size], y_local)
T.reduce_absmax(y_local, y_amax_local, dim=1)
for i in T.Parallel(blk_m):
y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
y_s_local[i] = y_amax_local[i] / fp8_max
for i, j in T.Parallel(blk_m, group_size):
y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
T.copy(y_q_local, y_q_local_fp8)
for i in T.Parallel(blk_m):
X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
T.copy(y_q_local_fp8, X_fp8[row * blk_m : (row + 1) * blk_m, row_g_id * group_size : (row_g_id + 1) * group_size])
return per_token_cast
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# this function does not support CPU tensors
assert x.dim() == 2
m, n = x.shape
new_n = ceil_div(n, 128) * 128
x_padded = torch.nn.functional.pad(x, (0, new_n - n))
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
return x_fp8, (x_amax / 448.0).view(m, -1)
def main(M=8192, N=8192, blk_m=8):
kernel = per_token_cast_to_fp8(M, N, blk_m)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
x_fp8, x_amax = kernel(x)
x_fp8_ref, x_amax_ref = ref_program(x)
print("x_fp8:", x_fp8, x_fp8.shape)
print("x_amax:", x_amax, x_amax.shape)
print("x_fp8_ref:", x_fp8_ref, x_fp8_ref.shape)
print("x_amax_ref:", x_amax_ref, x_amax_ref.shape)
torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
latency = profiler.do_bench()
print("Tile-lang: {:.2f} ms".format(latency))
from tilelang.profiler import do_bench
from example_triton_cast_to_fp8 import per_token_group_quant_fp8
def run_triton():
x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
return x_fp8_triton_, x_amax_triton_
x_fp8_triton, x_amax_triton = run_triton()
latency = do_bench(run_triton)
print("Triton: {:.2f} ms".format(latency))
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Epsilon to avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s
y_s_col_stride,
# Epsilon to avoid division by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
# Convert g_id, the flattened block coordinate, to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
y_s_ptr += scale_col * y_s_col_stride + scale_row
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum value used to avoid division by zero.
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert x.shape[-1] % group_size == 0, f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
import tilelang.testing
import example_group_per_split_token_cast_to_fp8
import example_per_token_cast_to_fp8
def test_example_group_per_split_token_cast_to_fp8():
example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
def test_example_per_token_cast_to_fp8():
example_per_token_cast_to_fp8.main(M=2048, N=512, blk_m=8)
if __name__ == "__main__":
tilelang.testing.main()
import os
import random
import pytest
os.environ["PYTHONHASHSEED"] = "0"
random.seed(0)
try:
import torch
except ImportError:
pass
else:
torch.manual_seed(0)
try:
import numpy as np
except ImportError:
pass
else:
np.random.seed(0)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
"""Ensure that at least one test is collected. Error out if all tests are skipped."""
known_types = {
"failed",
"passed",
"skipped",
"deselected",
"xfailed",
"xpassed",
"warnings",
"error",
}
if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
terminalreporter.write_sep(
"!",
(f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
)
pytest.exit("No tests were collected.", returncode=5)
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
def check_hopper():
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
return main
@tilelang.jit(out_idx=[2])
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
return main
def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument("--p", type=int, default=1, help="p")
args = parser.parse_args(argv)
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
block_m = 64
block_n = 128
block_k = 32
num_stages = 3
threads = 256
kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads)
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if __name__ == "__main__":
main()
import torch
import argparse
import itertools
import tilelang
import tilelang.language as T
def check_hopper():
if not torch.cuda.is_available():
return None
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
def ref_program(stride, padding, dilation):
def main(A, B):
A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C
return main
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
)
)
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
}
for c in _configs
]
return configs
def get_heuristic_config() -> dict:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
elif sm_version in {90}:
return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
else:
return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(
N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=thread_num) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
if is_hopper:
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
if is_hopper:
T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
else:
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = (access_h >= 0) and (access_w >= 0) and (access_h < H) and (access_w < W)
data_shared[i, j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
if is_hopper:
T.copy(out_local, out_shared)
T.copy(out_shared, out_flat[by * block_M, bx * block_N])
else:
T.copy(out_local, out_flat[by * block_M, bx * block_N])
return main
def main(
n: int = 128,
c: int = 128,
h: int = 64,
w: int = 64,
f: int = 128,
k: int = 3,
s: int = 1,
d: int = 1,
p: int = 1,
use_autotune: bool = False,
with_roller: bool = True,
):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
if use_autotune:
kernel = convolution(N, C, H, W, F, K, S, D, P)
else:
config = get_heuristic_config()
kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_prog)
profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
print(f"TileLang latency: {tilelang_latency}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--n", type=int, default=128, help="n")
parser.add_argument("--c", type=int, default=128, help="c")
parser.add_argument("--h", type=int, default=64, help="h")
parser.add_argument("--w", type=int, default=64, help="w")
parser.add_argument("--f", type=int, default=128, help="f")
parser.add_argument("--k", type=int, default=3, help="k")
parser.add_argument("--s", type=int, default=1, help="s")
parser.add_argument("--d", type=int, default=1, help="d")
parser.add_argument("--p", type=int, default=1, help="p")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
parser.add_argument("--with_roller", action="store_true", default=True, help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune, args.with_roller)
import tilelang.testing
import example_convolution
import example_convolution_autotune
# TODO(@cy): TMA with convolution must be fixed in the future.
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_convolution():
example_convolution.main([])
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_convolution_autotune():
example_convolution_autotune.main()
if __name__ == "__main__":
tilelang.testing.main()
from typing import Tuple
import torch
import tilelang.testing
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(42)
@tilelang.jit
def tl_gemm(
M,
N,
K,
block_N,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
T.float8_e4m3fn,
], "Currently only float8_e4m3 is supported"
assert out_dtype in [
T.bfloat16,
T.float32,
], "Currently only float16 and float32 are supported"
group_size = 128
block_M = 128
block_K = 128
A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, group_size))
B_shape = (N, K)
Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, T.float32),
scales_b: T.Tensor(Scales_B_shape, T.float32),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), T.float32)
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def ceildiv(a, b):
return (a + b - 1) // b
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
# A_scale: (M, K//128) ==> (M//128, K//128, 128)
# B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
# A_fp8: (M, K)
# B_fp8: (N, K)
# out_dtype: bfloat16 or float32
# return C: (M, N)
M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
for i in range(ceildiv(M, 128)):
for j in range(ceildiv(N, 128)):
c_acc.zero_()
for k in range(ceildiv(K, 128)):
c = torch._scaled_mm(
A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
scale_a=A_scales[i, k].view(128, 1).contiguous(),
scale_b=B_scales[j, k].view(1, 128).contiguous(),
out_dtype=torch.bfloat16,
)
c_acc += c.to(torch.float32)
C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
return C
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
A = torch.randn(M, K).to(torch.bfloat16).cuda()
B = torch.randn(N, K).to(torch.bfloat16).cuda()
A_fp8, A_scale = per_token_cast_to_fp8(A.clone())
B_fp8, B_scale = per_block_cast_to_fp8(B.clone())
C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
kernel(A_fp8, B_fp8, C, A_scale, B_scale)
# Get Reference Result
ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
diff = calc_diff(C, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
profiler = kernel.get_profiler()
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
print(f"latency: {latency} ms")
tflops = 2 * M * N * K / latency / 1e9
print(f"tflops: {tflops}")
def main():
assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32)
if __name__ == "__main__":
for dtype in [T.float8_e4m3fn]:
for out_dtype in [T.bfloat16, T.float32]:
for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32)
import tilelang.testing
from example_deepgemm_fp8_2xAcc import main
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_deepgemm_fp8_2xAcc():
main()
if __name__ == "__main__":
tilelang.testing.main()
# 🚀 How to write high-performance kernels with TileLang: taking MLA as an example
TileLang is a user-friendly AI programming language that significantly lowers the barrier to kernel programming, helping users quickly build customized operators. However, users still need to master certain programming techniques to better leverage TileLang's powerful capabilities. Here, we'll use MLA as an example to demonstrate how to write high-performance kernels with TileLang.
## Introduction to MLA
DeepSeek's MLA (Multi-Head Latent Attention) is a novel attention mechanism known for its hardware efficiency and significant improvements in model inference speed. Several deep learning compilers (such as [Triton](https://github.com/triton-lang/triton)) and libraries (such as [FlashInfer](https://github.com/flashinfer-ai/flashinfer)) have developed their own implementations of MLA. In February 2025, [FlashMLA](https://github.com/deepseek-ai/FlashMLA) was open-sourced on GitHub. FlashMLA utilizes [CUTLASS](https://github.com/NVIDIA/cutlass) templates and incorporates optimization techniques from [FlashAttention](https://github.com/Dao-AILab/flash-attention), achieving impressive performance.
## Benchmark Results
We benchmarked the performance of FlashMLA, TileLang, Torch, Triton, and FlashInfer under batch sizes of 64 and 128, with float16 data type, as shown in the figures below.
<figure style="text-align: center">
<a href="./figures/bs64_float16.png">
<img src="./figures/bs64_float16.png" alt="bs64_float16">
</a>
<figcaption style="text-align: center;">Figure 1: Performance under batch size=64</figcaption>
</figure>
<figure style="text-align: center">
<a href="./figures/bs128_float16.png">
<img src="./figures/bs128_float16.png" alt="bs128_float16">
</a>
<figcaption style="text-align: center;">Figure 2: Performance under batch size=128</figcaption>
</figure>
As shown in the results, TileLang achieves performance comparable to FlashMLA in most cases, significantly outperforming both FlashInfer and Triton.
Notably, **TileLang accomplishes this with just around 80 lines of Python code**, demonstrating its exceptional ease of use and efficiency. Let's dive in and see how TileLang achieves this.
## Implementation
First, let's review the core computation logic of traditional FlashAttention:
```python
# acc_s: [block_M, block_N]
# scores_max: [block_M]
# scores_scale: [block_M]
# acc_o: [block_M, dim]
for i in range(loop_range):
acc_s = Q @ K[i]
scores_max_prev = scores_max
scores_max = max(acc_s, dim=1)
scores_scale = exp(scores_max_prev - scores_max)
acc_o *= scores_scale
acc_s = exp(acc_s - scores_max)
acc_o += acc_s @ V[i]
...
```
Here, `acc_s` represents the `Q @ K` result in each iteration with dimensions `[block_M, block_N]`, while `acc_o` represents the current iteration's output with dimensions `[block_M, dim]`. Both `acc_s` and `acc_o` need to be stored in registers to reduce latency.
Compared to traditional attention operators like MHA (Multi-Head Attention) or GQA (Grouped Query Attention), a major challenge in optimizing MLA is its large head dimensions: `query` and `key` have head dimensions of 576 (512 + 64), while `value` has a head dimension of 512. This raises a significant issue: `acc_o` becomes too large, and with too few threads (e.g., 128 threads), register spilling occurs, severely impacting performance.
The question then becomes how to partition the matrix multiplication. On the Hopper architecture, most computation kernels use [`wgmma.mma_async`](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions) instructions for optimal performance. The `wgmma.mma_async` instruction organizes 4 warps (128 threads) into a warpgroup for collective MMA operations, and it requires a minimum M dimension of 64. This means each warpgroup's M tile can only be reduced to 64, but a 64*512 tile is too large for a single warpgroup: a 64*512 fp32 accumulator holds 32,768 values, i.e., 256 registers per thread across 128 threads, which already exceeds the 255-register-per-thread limit, so register spilling is unavoidable.
Therefore, our only option is to partition `acc_o` along the `dim` dimension, with two warpgroups computing the left and right halves of `acc_o`, respectively. However, this introduces another challenge: both warpgroups require the complete `acc_s` result as input.
Our solution is to have each warpgroup compute half of `acc_s` during `Q @ K` computation, then obtain the other half computed by the other warpgroup through shared memory.
### Layout Inference
While the above process may seem complex, don't worry: TileLang handles all these intricacies for you.
Figure 3 and Figure 4 illustrate the frontend TileLang script and its corresponding execution plan for MLA. Here, `T.gemm` represents a matrix multiplication operation, `transpose_B=True` indicates that matrix B is transposed, and `policy=FullCol` specifies that each warpgroup computes one column block (i.e., the result matrix is split along its vertical dimension). `T.copy` represents a buffer-to-buffer copy.
<figure style="text-align: center">
<a href="./figures/qk_layout.jpg">
<img src="./figures/qk_layout.jpg" alt="QK Layout">
</a>
<figcaption style="text-align: center;">Figure 3: Buffer shapes in Q @ K</figcaption>
</figure>
<figure style="text-align: center">
<a href="./figures/pv_layout.jpg">
<img src="./figures/pv_layout.jpg" alt="PV Layout">
</a>
<figcaption style="text-align: center;">Figure 4: Buffer shapes in acc_s @ V</figcaption>
</figure>
The mapping from TileLang frontend code to execution plan is accomplished through Layout Inference. Layout inference is a core optimization technique in TileLang. It automatically deduces the required buffer shapes and optimal layouts based on Tile-Operators (like `T.gemm`, `T.copy`, etc.), then generates the corresponding code. Here, we demonstrate a concrete example of buffer shape inference in MLA.
For instance, when computing `Q @ K`, TileLang infers that each warpgroup's `acc_s_0` shape should be `[blockM, blockN / 2]` based on the `policy=FullCol` annotation in `T.gemm`. Since this is followed by an `acc_s @ V` operation with `policy=FullCol`, which requires each warpgroup to have the complete `acc_s` result, TileLang deduces that `acc_s`'s shape at this point should be `[blockM, blockN]`. Consequently, TileLang can continue the inference process forward, determining that both `S_shared` and `acc_s` in `T.copy(S_shared, acc_s)` should have shapes of `[blockM, blockN]`.
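To make the pattern concrete, here is a rough, hedged sketch of the inferred split-GEMM pipeline. It is not the actual MLA kernel: the buffer names and the `T.GemmWarpPolicy.FullCol` spelling are assumptions, mirroring the `FullRow` policy used in the decode example in this repository.

```python
# Hedged sketch, not the real kernel: names and the FullCol spelling are illustrative.
T.gemm(Q_shared, K_shared, acc_s_0, transpose_B=True,
       policy=T.GemmWarpPolicy.FullCol)  # each warpgroup produces a [blockM, blockN/2] half
T.copy(acc_s_0, S_shared)                # exchange the halves through shared memory
T.copy(S_shared, acc_s)                  # each warpgroup reads the full [blockM, blockN] acc_s
T.gemm(acc_s, V_shared, acc_o,
       policy=T.GemmWarpPolicy.FullCol)  # acc_o is split along dim across the two warpgroups
```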
It's worth noting that our scheduling approach differs from FlashMLA's implementation strategy. In FlashMLA, `Q @ K` is assigned to a single warpgroup, while the `acc_o` partitioning scheme remains consistent with ours. Nevertheless, our scheduling approach still achieves comparable performance.
### Threadblock Swizzling
Threadblock swizzling is a common performance optimization technique in GPU kernel optimization. In GPU architecture, the L2 cache is a high-speed cache shared among multiple SMs (Streaming Multiprocessors). Threadblock swizzling optimizes data access patterns by remapping the scheduling order of threadblocks, thereby improving L2 cache hit rates. Traditional scheduling typically executes threadblocks in the natural order of the grid, which can lead to non-contiguous data access patterns between adjacent threadblocks, resulting in inefficient utilization of cached data. The swizzle technique employs mathematical mapping methods (such as diagonal or interleaved mapping) to adjust the execution order of threadblocks, ensuring that consecutively scheduled threadblocks access adjacent or overlapping data regions.
In TileLang, threadblock swizzling optimization can be implemented with just a single line of Python code:
```python
T.use_swizzle(panel_size: int, order: str = "row")
```
Here, `panel_size` specifies the width of the swizzled threadblock group, and `order` determines the swizzling pattern, which can be either "row" or "col".
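For example, the FP8 GEMM kernel in this repository enables threadblock swizzling with a panel size of 10 inside the kernel body:

```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
    # Remap the threadblock schedule to improve L2 cache reuse
    T.use_swizzle(panel_size=10)
    ...
```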
### Shared Memory Swizzling
In CUDA programming, shared memory is divided into multiple memory banks, with each bank capable of servicing one thread request per clock cycle in parallel. Bank conflicts occur when multiple threads simultaneously access different addresses mapped to the same bank, forcing these accesses to be serialized and degrading performance.
One common strategy to address bank conflicts is shared memory swizzling. This technique rearranges how data is stored in shared memory by remapping addresses that would originally fall into the same bank to different banks, thereby reducing conflicts. For example, XOR operations or other bit manipulations can be incorporated into address calculations to alter the data layout, resulting in more evenly distributed memory accesses across consecutive threads. This approach is particularly crucial for implementing high-performance computing tasks like matrix multiplication and convolution, as it can significantly improve memory access parallelism and overall execution efficiency.
Similarly, TileLang also supports shared memory swizzling. Users only need to add a single line of Python code:
```python
T.annotate_layout({
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
```
Here, `T.annotate_layout` allows users to specify any desired layout for a buffer. For convenience, TileLang provides the `make_swizzled_layout` primitive to automatically generate a swizzled layout.
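The convolution example in this repository applies the same annotation to several shared buffers at once:

```python
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
# Swizzle every shared buffer to reduce bank conflicts on load and store
T.annotate_layout({
    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
```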
### Warp-Specialization
The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
### Pipeline
Pipelining is a technique that improves memory access efficiency by overlapping memory access with computation. In TileLang, it is expressed through the `T.Pipelined` loop construct:
```python
for k in T.Pipelined(iterations, num_stages=stages):
    ...
```
Here, `iterations` is the number of loop iterations to pipeline, and `num_stages` is the number of pipeline stages. Multi-stage pipelining overlaps computation with memory access, which can significantly improve performance for memory-bound operators. However, a higher stage count consumes more shared memory, so the optimal configuration needs to be determined for each use case.
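For instance, the FP8 GEMM example in this repository uses a four-stage pipeline so that global-to-shared copies overlap with the GEMM:

```python
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
    # The copies below and the GEMM of earlier iterations are overlapped by the pipeline
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[bx * block_N, k * block_K], B_shared)
    T.gemm(A_shared, B_shared, C_local, transpose_B=True)
```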
### Split-KV
We have also implemented Split-KV optimization similar to [FlashDecoding](https://pytorch.org/blog/flash-decoding/). Specifically, when the batch size is small, parallel SM resources cannot be fully utilized due to low parallelism. In such cases, we can split the kv_ctx dimension across multiple SMs for parallel computation and then merge the results.
In our implementation, we have developed both split and combine kernels, allowing users to control the split size through a `num_split` parameter.
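The slicing itself takes only a few lines. In the split kernel of the decode example in this repository, each threadblock with split index `bz` walks its own contiguous slice of the KV sequence and accumulates a partial output, which the combine kernel later merges using the stored log-sum-exp values:

```python
# Each split handles seqlen_kv // num_split tokens of the KV cache.
loop_range = T.ceildiv(seqlen_kv // num_split, block_N)
for k in T.Pipelined(loop_range, num_stages=0):
    kv_start = (seqlen_kv // num_split) * bz + k * block_N
    kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
    T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
    T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
```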
# 🚀 High-Performance FlashMLA Implementation Using TileLang on AMD MI300X Accelerators
Following our previous demonstration of [high-performance FlashMLA implementation on NVIDIA Hopper architectures using TileLang](https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/README.md), this work presents an optimized implementation for AMD MI300X accelerators. We examine architectural differences and corresponding optimization strategies between these platforms.
## Architectural Considerations and Optimization Strategies
Key implementation differences between Hopper and MI300X architectures include:
1. **Instruction Set Variations**: The MI300X architecture eliminates the need for explicit Tensor Memory Accelerator (TMA) instructions and warp specialization, both of which are handled automatically by the compiler on Hopper; as a result, the source code looks identical on both platforms.
2. **Shared Memory Constraints**: With 64KB of shared memory compared to Hopper's 228KB, MI300X implementations require careful memory management. Our optimization strategy includes:
- Reducing software pipeline stages
- Register-based caching of Q matrices instead of shared memory utilization:
```python
# Original shared memory allocation
Q_shared = T.alloc_shared([block_H, dim], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
# Optimized register allocation
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
```
3. **Tile Size Flexibility**: The absence of WGMMA instructions on MI300X permits more flexible tile size selection, removing the requirement for block_m to be multiples of 64.
4. **Memory Bank Conflict Swizzling**: MI300X has different shared-memory bank-conflict rules than NVIDIA GPUs, so different swizzling strategies are required. This is also handled automatically by TileLang, so there are no visible differences in the code.
## Performance Evaluation
We conducted comparative performance analysis across multiple frameworks using float16 precision with batch sizes 64 and 128. The experimental results demonstrate:
<figure style="text-align: center">
<a href="../figures/flashmla-amd.png">
<img src="../figures/flashmla-amd.png" alt="AMD FlashMLA Performance Comparison">
</a>
<figcaption style="text-align: center;">Figure 1: Computational throughput comparison across frameworks (Batch sizes 64 and 128)</figcaption>
</figure>
Notably, TileLang achieves performance parity with hand-optimized assembly kernels (aiter-asm) in most test cases, ranging from 0.73x to 1.21x of their performance, while significantly outperforming Triton implementations (up to 6.5x faster). This is achieved with a concise 70-line Python implementation!
## Future Optimization Opportunities
1. **Memory Bank Conflict Mitigation**: Current implementations primarily address bank conflicts in NT layouts through TileLang's automatic optimization. Further investigation of swizzling techniques for alternative memory layouts remains an open research direction.
2. **Dimension Parallelization**: For large MLA dimensions (e.g., 576 elements), we propose investigating head dimension partitioning strategies to:
- Reduce shared memory pressure
- Improve compute-to-memory access ratios
- Enhance parallelism through dimension-wise task distribution
## Acknowledgment
We would like to express our sincere gratitude to the AMD ROCm and Composable Kernel team for their outstanding contributions. We have learned a great deal from the ROCm software stack.
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse
def get_configs():
import itertools
BLOCK_N = [16, 32, 64, 128]
BLOCK_H = [16, 32, 64, 128]
num_split = [1, 2, 4, 8, 16, 32]
threads = [128, 256]
_configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
return [
{
"block_N": c[0],
"block_H": c[1],
"num_split": c[2],
"threads": c[3],
}
for c in _configs
]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128):
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by):
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=0):
T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
# T.copy(acc_s, S_shared)
T.copy(acc_s, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz):
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=0):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, kv_head_num, seqlen_kv, dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, kv_head_num, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
parser.add_argument("--autotune", action="store_true", help="auto tune")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
enable_autotune = args.autotune
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 32
BLOCK_H = 64
num_split = 4
threads = 128
if enable_autotune:
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
else:
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
ref_output = ref_program(*input_tensors)
print(f"Tilelang output: {tilelang_output}")
print(f"Ref output: {ref_output}")
torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
# ruff: noqa
import argparse
import math
import random
import torch
import triton
import triton.language as tl
import tilelang
from tilelang.profiler import do_bench
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q,
h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_torch, lse_torch = ref_mla()
t = triton.testing.do_bench(ref_mla)
return out_torch, lse_torch, t
@triton.jit
def _mla_attn_kernel(
Q_nope,
Q_pe,
Kv_c_cache,
K_pe_cache,
Req_to_tokens,
B_seq_len,
O,
sm_scale,
stride_q_nope_bs,
stride_q_nope_h,
stride_q_pe_bs,
stride_q_pe_h,
stride_kv_c_bs,
stride_k_pe_bs,
stride_req_to_tokens_bs,
stride_o_b,
stride_o_h,
stride_o_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
HEAD_DIM_KPE: tl.constexpr,
):
cur_batch = tl.program_id(1)
cur_head_id = tl.program_id(0)
split_kv_id = tl.program_id(2)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
q_pe = tl.load(Q_pe + offs_q_pe)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk = tl.dot(q_nope, k_c.to(q_nope.dtype))
offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
qk *= sm_scale
qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))
v_c = tl.trans(k_c)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v_c.dtype), v_c)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
def _mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
):
batch_size, head_num = q_nope.shape[0], q_nope.shape[1]
head_dim_ckv = q_nope.shape[-1]
head_dim_kpe = q_pe.shape[-1]
BLOCK_H = 16
BLOCK_N = 64
grid = (
triton.cdiv(head_num, BLOCK_H),
batch_size,
num_kv_splits,
)
_mla_attn_kernel[grid](
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
req_to_tokens,
b_seq_len,
attn_logits,
sm_scale,
# stride
q_nope.stride(0),
q_nope.stride(1),
q_pe.stride(0),
q_pe.stride(1),
kv_c_cache.stride(-2),
k_pe_cache.stride(-2),
req_to_tokens.stride(0),
attn_logits.stride(0),
attn_logits.stride(1),
attn_logits.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
NUM_KV_SPLITS=num_kv_splits,
PAGE_SIZE=page_size,
HEAD_DIM_CKV=head_dim_ckv,
HEAD_DIM_KPE=head_dim_kpe,
num_stages=1,  # num_stages=2 runs out of memory on AMD GPUs
)
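# Stage 2: combine the per-split partial results.
# One program per (batch, head) walks the NUM_KV_SPLITS partial outputs, reads each split's
# log-sum-exp from the extra slot, and merges them with a numerically stable running max/sum
# before writing the final normalized output.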
@triton.jit
def _mla_softmax_reducev_kernel(
Logits,
B_seq_len,
O,
stride_l_b,
stride_l_h,
stride_l_s,
stride_o_b,
stride_o_h,
NUM_KV_SPLITS: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)
offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)
n_e_max = tl.maximum(logits_1, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(logits_1 - n_e_max)
acc += exp_logic * logits
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
acc / e_sum,
)
def _mla_softmax_reducev(
logits,
o,
b_seq_len,
num_kv_splits,
):
batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]
grid = (batch_size, head_num)
_mla_softmax_reducev_kernel[grid](
logits,
b_seq_len,
o,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=num_kv_splits,
HEAD_DIM_CKV=head_dim_ckv,
)
def mla_decode_triton(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
o,
req_to_tokens,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
):
assert num_kv_splits == attn_logits.shape[2]
_mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
)
_mla_softmax_reducev(
attn_logits,
o,
b_seq_len,
num_kv_splits,
)
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA head dim d (nope + rope) must be larger than the nope dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
t = triton.testing.do_bench(flash_mla_triton)
return out_flash, None, t
FUNC_TABLE = {
"torch": run_torch_mla,
"flash_mla_triton": run_flash_mla_triton,
}
def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert baseline in FUNC_TABLE
assert target in FUNC_TABLE
baseline_func = FUNC_TABLE[baseline]
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flash_mla_triton"] and baseline not in ["flash_mla_triton"]:
# flash_mla_triton doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}"
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
available_targets = [
"torch",
"flash_mla_triton",
]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [128]
for seqlen in [1024, 2048, 4096, 8192, 16384]
for head in [128]
]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--baseline", type=str, default="torch")
parser.add_argument("--target", type=str, default="torch")
parser.add_argument("--all", action="store_true")
parser.add_argument("--one", action="store_true")
parser.add_argument("--compare", action="store_true")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
)
elif args.one:
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
# ruff: noqa
import argparse
import math
import random
import torch
import triton
import triton.language as tl
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q,
h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_torch, lse_torch = ref_mla()
t = triton.testing.do_bench(ref_mla)
return out_torch, lse_torch, t
@triton.jit
def _mla_attn_kernel(
Q_nope,
Q_pe,
Kv_c_cache,
K_pe_cache,
Req_to_tokens,
B_seq_len,
O,
sm_scale,
stride_q_nope_bs,
stride_q_nope_h,
stride_q_pe_bs,
stride_q_pe_h,
stride_kv_c_bs,
stride_k_pe_bs,
stride_req_to_tokens_bs,
stride_o_b,
stride_o_h,
stride_o_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
HEAD_DIM_KPE: tl.constexpr,
):
cur_batch = tl.program_id(1)
cur_head_id = tl.program_id(0)
split_kv_id = tl.program_id(2)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
q_pe = tl.load(Q_pe + offs_q_pe)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk = tl.dot(q_nope, k_c.to(q_nope.dtype))
offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
qk *= sm_scale
qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))
v_c = tl.trans(k_c)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v_c.dtype), v_c)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
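# Host-side launcher for the stage-1 kernel. BLOCK_H = 16 groups 16 query heads per program so the
# q/k products are real matrix tiles for tl.dot; the grid covers (head blocks, batch, kv splits).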
def _mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
):
batch_size, head_num = q_nope.shape[0], q_nope.shape[1]
head_dim_ckv = q_nope.shape[-1]
head_dim_kpe = q_pe.shape[-1]
BLOCK_H = 16
BLOCK_N = 64
grid = (
triton.cdiv(head_num, BLOCK_H),
batch_size,
num_kv_splits,
)
_mla_attn_kernel[grid](
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
req_to_tokens,
b_seq_len,
attn_logits,
sm_scale,
# stride
q_nope.stride(0),
q_nope.stride(1),
q_pe.stride(0),
q_pe.stride(1),
kv_c_cache.stride(-2),
k_pe_cache.stride(-2),
req_to_tokens.stride(0),
attn_logits.stride(0),
attn_logits.stride(1),
attn_logits.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
NUM_KV_SPLITS=num_kv_splits,
PAGE_SIZE=page_size,
HEAD_DIM_CKV=head_dim_ckv,
HEAD_DIM_KPE=head_dim_kpe,
num_stages=1,  # num_stages=2 runs out of memory on AMD GPUs
)
@triton.jit
def _mla_softmax_reducev_kernel(
Logits,
B_seq_len,
O,
stride_l_b,
stride_l_h,
stride_l_s,
stride_o_b,
stride_o_h,
NUM_KV_SPLITS: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)
offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)
n_e_max = tl.maximum(logits_1, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(logits_1 - n_e_max)
acc += exp_logic * logits
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
acc / e_sum,
)
def _mla_softmax_reducev(
logits,
o,
b_seq_len,
num_kv_splits,
):
batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]
grid = (batch_size, head_num)
_mla_softmax_reducev_kernel[grid](
logits,
b_seq_len,
o,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=num_kv_splits,
HEAD_DIM_CKV=head_dim_ckv,
)
def mla_decode_triton(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
o,
req_to_tokens,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
):
assert num_kv_splits == attn_logits.shape[2]
_mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
)
_mla_softmax_reducev(
attn_logits,
o,
b_seq_len,
num_kv_splits,
)
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA head dim d (nope + rope) must be larger than the nope dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
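# Queries are split into the no-RoPE part (first dv dims) and the RoPE part (remaining d - dv dims);
# the paged KV cache is flattened to rows so the kernel can index it through block_table. Per kv
# split, attn_logits keeps dv partial-output values plus one log-sum-exp slot (hence dv + 1).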
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
t = triton.testing.do_bench(flash_mla_triton)
return out_flash, None, t
FUNC_TABLE = {
"torch": run_torch_mla,
"flash_mla_triton": run_flash_mla_triton,
}
def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert baseline in FUNC_TABLE
assert target in FUNC_TABLE
baseline_func = FUNC_TABLE[baseline]
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flash_mla_triton"] and baseline not in ["flash_mla_triton"]:
# flash_mla_triton doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}"
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
available_targets = [
"torch",
"flash_mla_triton",
]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [64, 128]
for seqlen in [1024, 2048, 4096, 8192, 16384]
for head in [128]
]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--baseline", type=str, default="torch")
parser.add_argument("--target", type=str, default="flash_mla_triton")
parser.add_argument("--all", action="store_true")
parser.add_argument("--one", action="store_true")
parser.add_argument("--compare", action="store_true")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
)
elif args.one:
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
# ruff: noqa
import argparse
import math
import random
import torch
import triton
import triton.language as tl
import tilelang
from tilelang.profiler import do_bench
from example_mla_decode_paged import mla_decode_tilelang
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q,
h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_torch, lse_torch = ref_mla()
t = triton.testing.do_bench(ref_mla)
return out_torch, lse_torch, t
@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
out_flash, lse_flash = flash_mla()
t = triton.testing.do_bench(flash_mla)
return out_flash, lse_flash, t
@torch.inference_mode()
def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
# pip install flashinfer-python
import flashinfer
assert d > dv, "MLA head dim d (nope + rope) must be larger than the nope dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
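# Build CSR-style page metadata for FlashInfer: kv_indptr[i + 1] - kv_indptr[i] is the number of
# pages owned by request i, and kv_indices lists those page ids taken from block_table in order.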
kv_indptr = [0]
kv_indices = []
for i in range(b):
seq_len = cache_seqlens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_table[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
q_indptr = torch.arange(0, b + 1).int() * s_q
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
mla_wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
cache_seqlens,
h_q,
dv,
d - dv,
block_size,
causal,
1 / math.sqrt(d),
q.dtype,
blocked_k.dtype,
)
def flashinfer():
output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True)
return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
out_flash, lse_flash = flashinfer()
t = triton.testing.do_bench(flashinfer)
return out_flash, lse_flash, t
@triton.jit
def _mla_attn_kernel(
Q_nope,
Q_pe,
Kv_c_cache,
K_pe_cache,
Req_to_tokens,
B_seq_len,
O,
sm_scale,
stride_q_nope_bs,
stride_q_nope_h,
stride_q_pe_bs,
stride_q_pe_h,
stride_kv_c_bs,
stride_k_pe_bs,
stride_req_to_tokens_bs,
stride_o_b,
stride_o_h,
stride_o_s,
BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
HEAD_DIM_KPE: tl.constexpr,
):
cur_batch = tl.program_id(1)
cur_head_id = tl.program_id(0)
split_kv_id = tl.program_id(2)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :]
q_pe = tl.load(Q_pe + offs_q_pe)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None]
k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk = tl.dot(q_nope, k_c.to(q_nope.dtype))
offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None]
k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0)
qk += tl.dot(q_pe, k_pe.to(q_pe.dtype))
qk *= sm_scale
qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf"))
v_c = tl.trans(k_c)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v_c.dtype), v_c)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
def _mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
):
batch_size, head_num = q_nope.shape[0], q_nope.shape[1]
head_dim_ckv = q_nope.shape[-1]
head_dim_kpe = q_pe.shape[-1]
BLOCK_H = 16
BLOCK_N = 64
grid = (
triton.cdiv(head_num, BLOCK_H),
batch_size,
num_kv_splits,
)
_mla_attn_kernel[grid](
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
req_to_tokens,
b_seq_len,
attn_logits,
sm_scale,
# stride
q_nope.stride(0),
q_nope.stride(1),
q_pe.stride(0),
q_pe.stride(1),
kv_c_cache.stride(-2),
k_pe_cache.stride(-2),
req_to_tokens.stride(0),
attn_logits.stride(0),
attn_logits.stride(1),
attn_logits.stride(2),
BLOCK_H=BLOCK_H,
BLOCK_N=BLOCK_N,
NUM_KV_SPLITS=num_kv_splits,
PAGE_SIZE=page_size,
HEAD_DIM_CKV=head_dim_ckv,
HEAD_DIM_KPE=head_dim_kpe,
)
@triton.jit
def _mla_softmax_reducev_kernel(
Logits,
B_seq_len,
O,
stride_l_b,
stride_l_h,
stride_l_s,
stride_o_b,
stride_o_h,
NUM_KV_SPLITS: tl.constexpr,
HEAD_DIM_CKV: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_seq_len + cur_batch)
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32)
offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv
offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s)
logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s)
n_e_max = tl.maximum(logits_1, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(logits_1 - n_e_max)
acc += exp_logic * logits
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv,
acc / e_sum,
)
def _mla_softmax_reducev(
logits,
o,
b_seq_len,
num_kv_splits,
):
batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2]
grid = (batch_size, head_num)
_mla_softmax_reducev_kernel[grid](
logits,
b_seq_len,
o,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=num_kv_splits,
HEAD_DIM_CKV=head_dim_ckv,
num_warps=4,
num_stages=2,
)
def mla_decode_triton(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
o,
req_to_tokens,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
):
assert num_kv_splits == attn_logits.shape[2]
_mla_attn(
q_nope,
q_pe,
kv_c_cache,
k_pe_cache,
attn_logits,
req_to_tokens,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
)
_mla_softmax_reducev(
attn_logits,
o,
b_seq_len,
num_kv_splits,
)
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA head dim d (nope + rope) must be larger than the nope dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
t = triton.testing.do_bench(flash_mla_triton)
return out_flash, None, t
@torch.inference_mode()
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "MLA head dim d (nope + rope) must be larger than the nope dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dpe = d - dv
num_kv_splits = 1
BLOCK_N = 64
BLOCK_H = 64
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
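# The TileLang kernel is compiled (and cached) for these exact shapes. With num_kv_splits == 1 no
# KV splitting is performed, so glse / out_partial mainly satisfy the kernel's argument list.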
def flash_mla_tilelang():
out = kernel(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, dpe),
blocked_k_nope.view(-1, h_kv, dv),
blocked_k_pe.view(-1, h_kv, dpe),
block_table,
cache_seqlens,
glse,
out_partial,
)
return out.view([b, s_q, h_q, dv])
out_flash = flash_mla_tilelang()
t = do_bench(flash_mla_tilelang)
return out_flash, None, t
FUNC_TABLE = {
"torch": run_torch_mla,
"tilelang": run_flash_mla_tilelang,
"flash_mla": run_flash_mla,
"flashinfer": run_flashinfer,
"flash_mla_triton": run_flash_mla_triton,
}
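# Typical invocations (flags defined in get_args() below), e.g.:
#   python <this script> --compare --baseline torch --target tilelang
#   python <this script> --one --target flash_mla_triton
#   python <this script> --all
# Each run writes a <benchmark_type>_perf.csv with the measured bandwidth per shape.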
def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert baseline in FUNC_TABLE
assert target in FUNC_TABLE
baseline_func = FUNC_TABLE[baseline]
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
# flashinfer returns lse in a different convention
# flash_mla_triton and flash_mla_tilelang don't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
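# Rough roofline counters for a single decode step:
#   FLOPS: 2 flops per multiply-accumulate; Q.K^T touches d dims and P.V touches dv dims, hence the
#          (d + dv) factor over s_q * total_seqlens * h_q positions.
#   bytes: KV-cache reads (total_seqlens * h_kv * d) plus Q reads and O writes, scaled by the element
#          size. Dividing bytes by the measured time gives the reported GB/s.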
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
assert target in FUNC_TABLE
target_func = FUNC_TABLE[target]
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
available_targets = [
"torch",
"tilelang",
"flash_mla",
"flashinfer",
"flash_mla_triton",
]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [128]
for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]
for head in [128]
]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--baseline", type=str, default="torch")
parser.add_argument("--target", type=str, default="tilelang")
parser.add_argument("--all", action="store_true")
parser.add_argument("--one", action="store_true")
parser.add_argument("--compare", action="store_true")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
with open(f"{benchmark_type}_perf.csv", "w") as fout:
fout.write("name,batch,seqlen,head,bw\n")
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{prefb:.0f}\n"
)
elif args.one:
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
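# MLA decode attention in TileLang: queries attend to a compressed latent KV (dim) plus a small RoPE
# component (pe_dim). Two kernel variants are generated: a single-pass one and a split-KV one whose
# partial results are merged via their log-sum-exp values. Softmax runs in base 2 (exp2/log2), so the
# scale below folds in log2(e).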
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :])
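# Split-KV variant: the KV sequence is cut into num_split contiguous slices and the grid gains a
# third axis (bz). Each slice produces a partial output in Output_partial and its log2-domain
# log-sum-exp in glse, which combine() later merges.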
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = hid // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :])
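# combine(): standard log-sum-exp merge of the split results (all in log2 domain):
#   lse_total = log2(sum_k 2^(lse_k - lse_max)) + lse_max
#   out       = sum_k 2^(lse_k - lse_total) * out_k
# A running max over glse keeps the exponentials in range.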
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, hid, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, hid, k, i]
lse_local_split[0] = glse[bz, hid, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, hid, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def main(
batch=1,
heads=128,
kv_heads=1,
kv_ctx=8192,
dim=512,
pe_dim=64,
):
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim) ** -0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=132, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
from tilelang.profiler import do_bench
import math
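# Paged-KV MLA decode kernel in TileLang: the flattened KV cache is addressed through a block table
# and per-request cache_seqlens instead of a dense [batch, seqlen] layout.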
@tilelang.jit(
out_idx=[8],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None):
if softmax_scale is None:
softmax_scale = (dv + dpe) ** -0.5
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = h_q // h_kv
VALID_BLOCK_H = min(block_H, kv_group_num)
assert h_kv == 1, "h_kv must be 1"
assert block_size >= block_N and block_size % block_N == 0, "block_size must be at least block_N and a multiple of block_N"
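# Paged addressing: logical token t of request b lives at physical row
#   BLOCK_TABLE[b, t // block_size] * block_size + t % block_size
# of the flattened KV tensors. block_size % block_N == 0 guarantees a block_N tile never straddles a
# page boundary, so one BLOCK_TABLE lookup per tile is enough.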
@T.macro
def flash_mla_kernel(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
CACHE_SEQLENS: T.Tensor([batch], T.int32),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
KV_shared = T.alloc_shared([block_N, dv], dtype)
K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
O_shared = T.alloc_shared([block_H, dv], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
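# Iterate the KV blocks in reverse (k = loop_range - 1 - kr) so the only partially filled block (the
# sequence tail) is processed first; masking against CACHE_SEQLENS is then needed only when kr == 0.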
for kr in T.Pipelined(loop_range, num_stages=2):
k = loop_range - 1 - kr
kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
if kr == 0:
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_mla_split_kv_kernel(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
CACHE_SEQLENS: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
KV_shared = T.alloc_shared([block_N, dv], dtype)
K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
O_shared = T.alloc_shared([block_H, dv], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
blocks_per_split = T.floordiv(total_blocks, num_split)
remaining_blocks = T.floormod(total_blocks, num_split)
loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size
T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
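# `combine` merges the per-split partials with a numerically stable
# log-sum-exp: lse_total = log2(sum_s exp2(lse_s - lse_max)) + lse_max, then
# Output = sum_s exp2(lse_s - lse_total) * Output_partial_s.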
@T.macro
def combine(
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(h_q, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dv], dtype)
o_accum_local = T.alloc_fragment([dv], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dv):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dv):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dv):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)
if num_split > 1:
return main_split
else:
return main_no_split
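# A minimal PyTorch sketch (illustrative only, not used by the kernels above)
# of the same split-combine rule for one (batch, head) slice; `o_partial` and
# `lse` are hypothetical stand-ins for Output_partial[b, h] and glse[b, h].
def combine_reference(o_partial: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # o_partial: [num_split, dv] partial outputs; lse: [num_split] base-2 log-sum-exp per split
    lse_max = lse.max()
    lse_total = torch.log2(torch.exp2(lse - lse_max).sum()) + lse_max
    scale = torch.exp2(lse - lse_total)  # per-split weights
    return (o_partial * scale[:, None]).sum(dim=0)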
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
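# The reference repeats K/V across each group of query heads (repeat_interleave)
# so plain scaled-dot-product attention reproduces the grouped-query/MQA
# semantics of the kernel; it also returns the log-sum-exp of the scores.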
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
# cache_seqlens: [b]
blocked_v = blocked_k[..., :dv]
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q,
h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out.to(dtype), lse.to(dtype)
out_torch, _ = ref_mla()
return out_torch
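# run_tilelang_mla feeds the kernel a paged KV cache: blocked_k is laid out as
# [num_blocks, block_size, h_kv, d], and block_table maps (batch, logical block)
# to a physical block id, which the kernel turns into a row offset via
# block_table[b, pos // block_size] * block_size + pos % block_size. In this
# benchmark block_table is simply an identity (arange) mapping.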
def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "MLA query/key head dim d (= dv + dpe) must be larger than the value head dim dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dpe = d - dv
num_kv_splits = 1  # a single split selects the no-split kernel path in mla_decode_tilelang
BLOCK_N = 64
BLOCK_H = min(64, h_q // h_kv)
softmax_scale = d**-0.5
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang():
out = profiler.func(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, dpe),
blocked_k_nope.view(-1, h_kv, dv),
blocked_k_pe.view(-1, h_kv, dpe),
block_table,
cache_seqlens,
glse,
out_partial,
)
return out.view([b, s_q, h_q, dv])
out_flash = flash_mla_tilelang()
t = do_bench(flash_mla_tilelang)
out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
print("All close")
return out_flash, t
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--h_q", type=int, default=128, help="q heads number")
parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
parser.add_argument("--dv", type=int, default=512, help="value head dim")
args = parser.parse_args()
b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
device = "cuda"
dtype = torch.float16
s_q = 1 # for decode, s_q = 1
block_size = 64
cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
dpe = d - dv
causal = True
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
total_flops = s_q * total_seqlens * h_q * d * 2  # counts the QK^T GEMM only; the PV GEMM would add 2 * s_q * total_seqlens * h_q * dv
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
out_flash, latency = run_tilelang_mla(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from tilelang.carver.arch import driver
from einops import rearrange, einsum
import argparse
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
sm_num = driver.get_num_sms()
@T.prim_func
def main_split_persistent(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(sm_num, threads=256) as (block_id):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
# O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.use_swizzle(10)
total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split
waves = T.ceildiv(total_tiles, sm_num)
for w in T.serial(waves):
tile_id = sm_num * w + block_id
bid = tile_id // ((heads // min(block_H, kv_group_num)) * num_split)
hid = tile_id // num_split % (heads // min(block_H, kv_group_num))
sid = tile_id % num_split
cur_kv_head = hid // (kv_group_num // block_H)
if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split:
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = (seqlen_kv // num_split) * sid + k * block_N
kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid])
# T.copy(acc_o, O_shared)
T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :])
T.sync_grid()
waves = T.ceildiv(heads * batch, sm_num)
for w in T.serial(waves):
tile_id = sm_num * w + block_id
hid = tile_id // batch
bid = tile_id % batch
if bid < batch and hid < heads:
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bid, hid, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bid, hid, k, i]
lse_local_split[0] = glse[bid, hid, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bid, hid, i] = o_accum_local[i]
return main_split_persistent
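# main_split_persistent is a persistent kernel: it launches only sm_num blocks,
# loops over all (batch, head-group, split) tiles in waves, and uses
# T.sync_grid() so the in-kernel combine pass only starts once every partial
# output and its log-sum-exp have been written.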
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]
    - Output_partial (Tensor): [batch, heads, num_split, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_per_group, dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_per_group, pe_dim]
kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
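# Note: the kernel evaluates softmax in base 2 (its scale folds in
# log2(e) = 1.44269504), while this reference uses the ordinary exp/softmax
# form; the two are equivalent since exp(x * s) == exp2(x * s * log2(e)).
# Purely illustrative check:
#   x = torch.randn(8); s = 0.1
#   torch.allclose(torch.exp(x * s), torch.exp2(x * s * 1.44269504))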
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
num_split = 2
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
if __name__ == "__main__":
main()
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
@tilelang.jit(
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=[
"-O3",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10",
"-DNDEBUG",
],
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
sm_scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid):
Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype)
Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype)
Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype)
K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype)
acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
sumexp = T.alloc_fragment([block_H], accum_dtype)
sum_exp_shared = T.alloc_shared([block_H], accum_dtype)
sumexp_i = T.alloc_fragment([block_H], accum_dtype)
alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared")
alpha_local = T.alloc_fragment([block_H], accum_dtype)
m_i = T.alloc_fragment([block_H], accum_dtype)
m_i_prev = T.alloc_fragment([block_H], accum_dtype)
# TODO: Multi buffer
bar_q = T.alloc_barrier(arrive_count=384)
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
bar_k_0_free = T.alloc_barrier(arrive_count=256)
bar_k_1_free = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
cur_kv_head = hid // (kv_group_num // block_H)
NI = T.ceildiv((seqlen_kv // num_split), block_N)
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30))  # large negative finite value instead of -inf, so m_i_prev - m_i never becomes -inf - (-inf) = NaN
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
T.clear(acc_s)
T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
if i_i != 0:
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, out=m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # row sums for this block only; folded into the running sumexp below
for h_i in T.Parallel(block_H):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_0_free[0])
# Buffer 1
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
T.clear(acc_s)
T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # row sums for this block only; folded into the running sumexp below
for h_i in T.Parallel(block_H):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_1_free[0])
# Rescale
for h_i in T.Parallel(block_H):
sum_exp_shared[h_i] = sumexp[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(block_H):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
T.fill(acc_o_r, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
T.barrier_arrive(bar_k_0_free[0])
T.barrier_arrive(bar_sScale_and_sS_free)
# Buffer 1
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
T.barrier_arrive(bar_k_1_free[0])
if i_i != T.ceildiv(NI, 2) - 1:
T.barrier_arrive(bar_sScale_and_sS_free)
# Rescale
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
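# flash_attn uses warp specialization over its 384 threads: tx < 128 computes
# the scores and the left half of the output, 128 <= tx < 256 computes the
# right half, and tx >= 256 acts as a producer issuing asynchronous copies of
# KV/K_pe tiles into two shared-memory buffers, with barriers handing
# ready/free ownership between the groups. flash_attn_split below follows the
# same scheme, additionally offsetting kv_indices by the split's starting position.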
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz):
Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype)
Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype)
Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared_0_l = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_0_r = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_1_l = T.alloc_shared([block_N, dim // 2], dtype)
KV_shared_1_r = T.alloc_shared([block_N, dim // 2], dtype)
K_tail_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
K_tail_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
acc_o_l = T.alloc_fragment([block_H, dim // 2], accum_dtype)
acc_o_r = T.alloc_fragment([block_H, dim // 2], accum_dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
sumexp = T.alloc_fragment([block_H], accum_dtype)
sum_exp_shared = T.alloc_shared([block_H], accum_dtype)
sumexp_i = T.alloc_fragment([block_H], accum_dtype)
alpha_shared = T.alloc_shared([block_H], accum_dtype, scope="shared")
alpha_local = T.alloc_fragment([block_H], accum_dtype)
m_i = T.alloc_fragment([block_H], accum_dtype)
m_i_prev = T.alloc_fragment([block_H], accum_dtype)
# TODO: Multi buffer
bar_q = T.alloc_barrier(arrive_count=384)
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
bar_k_0_free = T.alloc_barrier(arrive_count=256)
bar_k_1_free = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
cur_kv_head = hid // (kv_group_num // block_H)
NI = T.ceildiv((seqlen_kv // num_split), block_N)
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30))  # large negative finite value instead of -inf, so m_i_prev - m_i never becomes -inf - (-inf) = NaN
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
T.clear(acc_s)
T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
if i_i != 0:
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # row sums for this block only; folded into the running sumexp below
for h_i in T.Parallel(block_H):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_0_free[0])
# Buffer 1
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
T.clear(acc_s)
T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # row sums for this block only; folded into the running sumexp below
for h_i in T.Parallel(block_H):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_1_free[0])
# Rescale
for h_i in T.Parallel(block_H):
sum_exp_shared[h_i] = sumexp[h_i]
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_l[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(block_H):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2])
T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
T.fill(acc_o_r, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
T.barrier_arrive(bar_k_0_free[0])
T.barrier_arrive(bar_sScale_and_sS_free)
# Buffer 1
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
T.barrier_arrive(bar_k_1_free[0])
if i_i != T.ceildiv(NI, 2) - 1:
T.barrier_arrive(bar_sScale_and_sS_free)
# Rescale
for h_i, d_i in T.Parallel(block_H, dim // 2):
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, hid, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, hid, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, hid, k, i]
lse_local_split[0] = glse[bz, hid, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, hid, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]
    - Output_partial (Tensor): [batch, heads, num_split, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_per_group, dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups)  # [batch_size, num_head_groups, heads_per_group, pe_dim]
kv = rearrange(kv, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d")  # [batch_size, kv_head_num, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
def main(
batch=1,
heads=128,
kv_heads=1,
kv_ctx=8192,
dim=512,
pe_dim=64,
):
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim) ** -0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=132, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)