Commit cf6e11c9 authored by qisan

feat: merge dcu branch features

parents 3f27f85a d0436b7b
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
tilelang.testing.set_random_seed(0)
@tilelang.jit(out_idx=[2])
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = T.int8
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
tx = T.get_thread_binding()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
vj = index % block_K
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
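# Worked example of the cooperative dequantization copy above (added as a reading aid),
# using the float16 configuration exercised by run_gemm below (block_N=128, block_K=32,
# threads=128, num_bits=4): local_size = 128 / 16 = 8 fp16 elements per thread and
# local_size_compressed = 4 bytes. The packed B tile holds 128 * 32 / 2 = 2048 bytes, so
# the outer loop runs 2048 / (128 * 4) = 4 times; in each pass every thread loads 4
# contiguous bytes (vi = index // 16, vj = index % 16), expands them to 8 fp16 values via
# _tir_packed_to_unsigned_convert, and scatters them back into the 128 x 32
# B_dequantize_shared tile consumed by T.gemm.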
def run_gemm(
M,
N,
K,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
kernel = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
out = profiler.run_once()
assert out is not None
def ref_program(A, qB):
import torch
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half, device=A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(getattr(torch, out_dtype))
return C
profiler.assert_allclose(ref_program)
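# A vectorized sketch of the nibble unpacking performed element-by-element in ref_program
# above (a hypothetical helper added for illustration only; it is not called by the test):
# each int8 byte of qB stores two unsigned 4-bit values, low nibble first.
def _unpack_uint4_rows(qB: torch.Tensor) -> torch.Tensor:
    low = (qB & 0xF).to(torch.half)  # even output columns (pos = 0)
    high = ((qB >> 4) & 0xF).to(torch.half)  # odd output columns (pos = 1)
    return torch.stack((low, high), dim=-1).reshape(qB.shape[0], qB.shape[1] * 2)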
@tvm.testing.requires_package("bitblas")
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,
)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = T.int8
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_rows = 4
warp_cols = 4
warp_row_tiles = micro_size_x * warp_rows
warp_col_tiles = micro_size_y * warp_cols
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
reduce_k = 1
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == T.float16 else 64
chunk = block_K // reduce_k
is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
pad_factor = 8
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k // num_elems_per_byte,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte,
)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1)
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
}
)
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % (
block_N // micro_size_y
)
B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
rk=rk,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
rk=rk,
)
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern(
"handle",
"decode_i4u_to_f16",
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]),
8,
)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
if reduce_k > 1:
for n in T.serial(warp_rows * warp_cols * local_size):
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_local[n],
True,
reduced_accum_res[0],
rk,
dtype="handle",
)
)
if rk == 0:
C_local[n] = reduced_accum_res[0]
if rk == 0:
mma_emitter.stmatrix(
C_local,
C_shared,
)
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i, bx * block_N + vj] = C_shared[
i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y
]
return main
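# Note on the correctness check below: the packed int4 weights are pre-permuted on the host
# with bitblas' LadderPermutate (matching transform_kind_b) followed by LOP3Permutate (the
# interleaving expected by decode_i4u_to_f16), so the kernel above can ldmatrix the packed
# fragments and dequantize them in registers with one call_extern per warp-column fragment.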
def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source()
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
# src_code is the generated cuda source
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = T.int8
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=K,
transform_kind=transform_b,
transpose_matrix=True,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype=in_dtype,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
kernel(A, QLB, C)
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half, device=A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("Ref C: ", ref_c)
print("C: ", C)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3)
def main():
test_run_dequantize_gemm()
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == T.float16
assert val.dtype == T.uint8
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
e_f16 = e_f4 + tir.const(14, T.uint16)
m_f4 = f4 & tir.const(1, T.uint16)
m_f16 = m_f4
val_f16 = tir.reinterpret(
T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16)
)
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16)
return val_f16
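# A minimal worked example of the s1e2m1 -> f16 bit mapping above (a hypothetical helper
# for illustration only; it is not used by the kernels). Nibble 0b0111 (s=0, e=0b11, m=1)
# maps to exponent 3 + 14 = 17 and mantissa MSB 1, i.e. 2**(17 - 15) * 1.5 = 6.0, the
# largest positive e2m1 value. Note that the zero/subnormal case (e_f4 == 0) is not
# special-cased, matching the commented-out tir.Select above.
def _demo_fp4_nibble_to_fp16(f4: int) -> float:
    import struct
    s = f4 >> 3
    e_f16 = ((f4 & 6) >> 1) + 14
    m = f4 & 1
    bits = ((e_f16 | (s << 5)) << 10) | (m << 9)
    # Reinterpret the 16 bits as an IEEE half ('<H' -> '<e'), mirroring tir.reinterpret above.
    return struct.unpack("<e", struct.pack("<H", bits))[0]

assert _demo_fp4_nibble_to_fp16(0b0111) == 6.0
assert _demo_fp4_nibble_to_fp16(0b1011) == -1.5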
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f"{val_cpu:032b}"
print(name, binary_repr)
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 14
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = test_convert(
N,
K,
block_N,
block_K,
T.float16,
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
def get_configs():
block_M = [64, 128]
block_N = [64, 128]
block_K = [128, 256]
num_stages = [1, 2]
threads = [128, 256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
KK = K // split
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
for k in range(split):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = T.float16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(getattr(torch, dtypeC))
return C.transpose(0, 1)
def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=256, help="M")
parser.add_argument("--n", type=int, default=256, help="N")
parser.add_argument("--k", type=int, default=256, help="K")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse
def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == T.int8
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint8)
i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask
i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8))
i8 = i8_shifted >> tir.const(4, T.int8)
return i8
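# A minimal sketch of the sign-extension trick above (a hypothetical helper for
# illustration only): shifting the extracted nibble into the high half of an int8 and
# arithmetically shifting it back recovers the signed int4 value.
def _demo_int4_sign_extension():
    packed = torch.tensor([0xAB], dtype=torch.uint8)  # high nibble 0b1010, low nibble 0b1011
    as_i8 = packed.view(torch.int8)
    low = (((as_i8 >> 0) & 0xF) << 4) >> 4  # 0b1011 -> -5
    high = (((as_i8 >> 4) & 0xF) << 4) >> 4  # 0b1010 -> -6
    assert low.item() == -5 and high.item() == -6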
def get_configs():
iter_params = dict(
block_M=[64, 128],
block_N=[64, 128],
block_K=[128, 256],
num_stages=[1, 2],
threads=[128, 256, 512],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def torch_convert(tensor):
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
i4_shifted = (val >> (pos * 4)) & mask
i4 = (i4_shifted << 4) >> 4
return i4.view(torch.int8)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def ref_program(A, qB):
dtypeC = T.int32
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(getattr(torch, dtypeC))
return C.transpose(0, 1)
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_local_shape = (block_N, block_K)
assert K % (block_K) == 0
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.")
latency = profiler.do_bench(warmup=50)
print(f"Tilelang: {latency} ms")
else:
best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
print(f"Best config: {best_config}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
parser.add_argument("--tune", action="store_true", help="Enable tuning")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
# main(M, N, K, True)
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,
)
@tilelang.jit
def dequantize_gemv(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: T.dtype = T.int8,
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
fast_decoding: bool = False,
trans_A: bool = False,
trans_B: bool = True,
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as the related bitblas.gpu.gemv.GEMV.sch_outer_reduction_with_config is not implemented"
)
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
storage_type = "".join(c for c in storage_dtype if not c.isdigit())
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
block_K = reduce_thread * micro_size_k
if group_size == -1:
group_size = K
A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
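# dp4a performs a 4-way int8 dot product accumulated into an int32, so the int8 path below
# consumes micro_size_k elements in groups of four; other dtypes fall back to the scalar
# multiply-accumulate loop.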
import_source: Optional[str] = None
func_name: str = ""
if fast_decoding is True:
# Lazy import to decrease the startup time
# as intrin registry may take a while to load
from tilelang.quantize import get_lop3_intrin_group
lop3_intrin_info = get_lop3_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
with_scaling=with_scaling,
with_zeros=False,
)
import_source = lop3_intrin_info["c_source"]
func_name = lop3_intrin_info["func_name"]
assert import_source is not None, "c_source not found in lop3_intrin_info"
assert func_name is not None, "func_name not found in lop3_intrin_info"
@T.prim_func
def main(
A: T.Tensor[A_shape, in_dtype],
B: T.Tensor[B_shape, storage_dtype],
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.import_source(import_source)
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
]
if fast_decoding:
T.call_extern(
func_name,
T.address_of(B_quant_local[0]),
T.address_of(B_dequantize_local[0]),
dtype=in_dtype,
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype
)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_dequantize_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
)
)
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return main
def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = T.float16
out_dtype = T.float16
accum_dtype = T.float16
num_bits = 4
storage_dtype = T.int8
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
trans_A = False
trans_B = True
group_size = -1
with_scaling = False
kernel = dequantize_gemv(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
num_bits,
storage_dtype,
source_format,
n_partition,
reduce_thread,
fast_decoding,
trans_A,
trans_B,
group_size,
with_scaling,
)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half, device=A.device)
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("C: ", C)
print("Ref C: ", ref_c)
# doesn't apply scaling, the absolute error is large
torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
from tilelang.quantize import _tir_u8_to_f4_to_bf16
from tilelang import tvm as tvm
from tvm import DataType
import torch
from dequantize_utils import torch_convert_bit_twiddling, assert_similar
from tilelang.autotuner import set_autotune_inputs
import argparse
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes
- num_stages: pipeline stages
- threads: thread counts
- split: K-splitting factor
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[128],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 1, 2],
threads=[128, 256, 512],
split=[1],
)
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(
M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype` and shape (M, K).
- B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits).
- Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size).
- Bias: per-expert, per-output bias, shape (E, N).
- topk_weights: router weights for the top-k experts for each token, flattened to shape (M * topk,).
- sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,).
- expert_ids: expert id for each token in the padded batch, shape (padding_M // block_M,).
- C: output tensor, shape (M, topk, N).
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split).
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., T.bfloat16).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to the intrinsic selector (default T.uint32).
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for the M, N, and K dimensions (defaults 128, 256, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the grouped, pipelined GEMM that:
- loads tiled blocks of A and packed B for each expert to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- applies per-token topk weights and bias,
- writes the final (M, topk, N) block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = block_N
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "c_source not found in mxfp_intrin_info"
assert func_name is not None, "func_name not found in mxfp_intrin_info"
# the dequant part is the same as in dequant_gemm
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k):
# import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry and multiplies the dequantized BF16 fragment
by 2^Scale (the stored exponent is applied directly via a shift, with no additional bias).
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale_shared: per-block scale tensor; each entry is used directly as an exponent, so the multiplicative scale
= 2^Scale.
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j // scale_size
], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), T.int32),
expert_ids: T.Tensor((padding_M // block_M), T.int32),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
sorted_token_ids_shared = T.alloc_shared((block_M), T.int32)
expert_id = T.alloc_local((1), T.int32) # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.use_swizzle(10)
if threads == 512:
T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block
for i in T.Parallel(block_M):
if sorted_token_ids_shared[i] != -1:
topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
# Get bias and scale based on the expert id
if with_bias:
T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared)
else:
T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j]
tx = T.get_thread_binding()
for k in T.Pipelined(K // block_K, num_stages=num_stages):
# Each thread copies a contiguous vector of 16 in_dtype elements per iteration
for copy_i in T.serial(block_M * block_K // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K + copy_j] = A[
sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j
]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = C_local[i, j] * topk_weights_shared[i]
T.copy(C_local, C_shared)
for copy_i in T.serial(block_M * block_N // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16):
C[
sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk,
bx * block_N + base % block_N + copy_j,
] = C_shared[base // block_N, base % block_N + copy_j]
return main
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
dtypeC = T.bfloat16
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
scale_size = K // Scale.shape[2]
assert scale_size == 32 # MXFP4
# Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda")
# Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M
token_id = sorted_token_ids[idx]
if token_id == -1:
continue
expert_id = expert_ids[idx // block_M]
topk_idx = token_id % topk
# Get the token embedding
token_embedding = A[token_id // topk]
# Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16))
# Compute the output for this token-expert pair
# token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
output = output.to(getattr(torch, dtypeC))
# Apply the topk weight
weight = topk_weights[token_id]
output = output * weight
# Store the result
C[token_id // topk, topk_idx] = output
return C
def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda")
Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token.
# For each of m tokens, topk unique experts are chosen. Shape: (m * topk,)
topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1)
tokens_experts = tokens_experts.reshape(m * topk)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.reshape(m * topk)
sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True)
sorted_token_ids = sorted_indices
unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True)
expert_ids = []
padded_token_ids = []
start = 0
for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()):
end = start + cnt
group_token_ids = sorted_token_ids[start:end]
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0:
# -1 for padding (`M` instead in vLLM moe_align_block_size())
group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")])
padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end
# sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M // block_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
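# A minimal sketch of the block alignment performed by get_data() above (a hypothetical
# helper for illustration only; vLLM's moe_align_block_size plays the same role): each
# expert's token list is padded with -1 up to a multiple of block_M, and expert_ids holds
# one entry per block_M-sized block of sorted_token_ids. The token ids here are simple
# placeholders for the flattened (token, slot) indices that get_data() would carry.
def _demo_block_alignment(counts=(3, 5), block_M=4):
    sorted_token_ids, expert_ids, start = [], [], 0
    for eid, cnt in enumerate(counts):
        pad = ((cnt + block_M - 1) // block_M) * block_M - cnt
        sorted_token_ids += list(range(start, start + cnt)) + [-1] * pad
        expert_ids += [eid] * ((cnt + block_M - 1) // block_M)
        start += cnt
    # counts=(3, 5), block_M=4 -> sorted_token_ids = [0, 1, 2, -1, 3, 4, 5, 6, 7, -1, -1, -1]
    #                             expert_ids       = [0, 1, 1]
    return sorted_token_ids, expert_ids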
def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False):
# Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841
threads = 512 # noqa: F841
split = 1 # noqa: F841
total_flops = 2 * m * n * k * topk
num_bits = 4
num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M)
if tune:
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
# Autotune with inputs manually composed
kernel = matmul(
m,
n,
k,
topk,
E,
padding_M,
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias,
)
else:
kernel = matmul(
m,
n,
k,
topk,
E,
padding_M,
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias,
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
)
output = kernel(
A,
qB,
Scale,
Bias,
topk_weights,
sorted_token_ids,
expert_ids,
)
print("Tilelang kernel run finished.")
ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
diff = (output - ref_output).abs()
max_val = diff.max()
max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--N", type=int, default=5760, help="N")
parser.add_argument("--K", type=int, default=2944, help="K")
parser.add_argument("--scale_size", type=int, default=32, help="scale size")
parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--E", type=int, default=32, help="E") # number of experts
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune)
import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper_tma
import example_dequant_groupedgemm_bf16_mxfp4_hopper
import example_dequant_gemm_w4a8
@tilelang.testing.requires_cuda
def test_example_dequant_gemv_fp16xint4():
example_dequant_gemv_fp16xint4.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_mxfp4_hopper():
example_dequant_gemm_bf16_mxfp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
example_dequant_gemm_bf16_mxfp4_hopper_tma.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_w4a8():
example_dequant_gemm_w4a8.main()
if __name__ == "__main__":
tilelang.testing.main()
from typing import Optional
import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio
class RegisterLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
ctx.save_for_backward(loss)
return x
@staticmethod
def backward(ctx, grad):
loss = ctx.saved_tensors
return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device)
register_loss = RegisterLossFunction.apply
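# register_loss passes `x` through unchanged in the forward pass and, in the backward pass,
# returns a gradient of ones for `loss`; attaching the auxiliary KL loss to the attention
# output this way lets a single o.backward(do) also backpropagate the indexer loss without
# a separate backward call.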
def ref_deepseek_sparse_attention_inner(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
index_sm_scale: Optional[float] = None,
):
dtype = q.dtype
q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights))
if index_sm_scale is None:
    index_sm_scale = index_q.shape[-1] ** -0.5
b, s = index_q.shape[:2]
# tl_topk_indices = tl_topk_indices.to(torch.int64)
# tl_topk_indices[tl_topk_indices == -1] = s
causal_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2")
index_logits = F.relu(index_logits)
index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale
index_logits = torch.where(causal_mask, index_logits, float("-inf"))
topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices)
topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
index_topk_score = topk_score
if sm_scale is None:
sm_scale = kv.shape[-1] ** -0.5
h = q.shape[-2]
index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_(
dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool)
)[:, :, :-1]
mask = repeat(causal_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h)
k, v = kv, kv[..., :dim_v]
logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale
logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d")
attn_score = attn_score.sum(dim=-2) # [b, s1, s2]
attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)
loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum")
o = register_loss(o, loss)
return o.to(dtype), topk_indices
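# Note on the mask construction above (a minimal sketch, comments only): scattering the
# top-k indices into a (b, s, s + 1) boolean buffer and then dropping the last column lets
# an index value of s act as a harmless padding slot. For a single row with s = 3:
#
#     idx = torch.tensor([[0, 2, 3]])              # 3 == s marks a padded slot
#     buf = torch.zeros(1, 4, dtype=torch.bool)    # s + 1 columns
#     buf.scatter_(-1, idx, True)
#     mask = buf[:, :-1]                           # -> [True, False, True]; the pad is dropped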
def ref_deepseek_sparse_attention(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
index_sm_scale: Optional[float] = None,
):
all_o, all_topk_indices = [], []
for i in range(offsets.shape[0] - 1):
o, topk_indices = ref_deepseek_sparse_attention_inner(
q[None, offsets[i] : offsets[i + 1]],
kv[None, offsets[i] : offsets[i + 1]],
index_q[None, offsets[i] : offsets[i + 1]],
index_k[None, offsets[i] : offsets[i + 1]],
weights[None, offsets[i] : offsets[i + 1]],
topk,
dim_v,
sm_scale,
index_sm_scale,
)
all_o.append(o.squeeze(0))
all_topk_indices.append(topk_indices.squeeze(0))
o = torch.cat(all_o, dim=0)
topk_indices = torch.cat(all_topk_indices, dim=0)
return o, topk_indices
class DSAFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
):
# topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets)
o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets)
ctx.topk = topk
ctx.dim_v = dim_v
ctx.sm_scale = sm_scale
return o, topk_indices
@staticmethod
def backward(
ctx,
do: torch.Tensor,
_1: torch.Tensor,
):
q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
attn_score = sparse_mla_topk_reducesum_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v
).squeeze(-2)
dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale)
dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets)
return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None
def deepseek_sparse_attention(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
):
return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)
def test_kernel(
B=1,
S=2048,
H=16,
D=512,
tail_D=64,
index_D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_()
offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()
o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
o.backward(do)
q_grad, q.grad = q.grad, None
kv_grad, kv.grad = kv.grad, None
index_q_grad, index_q.grad = index_q.grad, None
index_k_grad, index_k.grad = index_k.grad, None
weights_grad, weights.grad = weights.grad, None
ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
ref_o.backward(do)
ref_q_grad, q.grad = q.grad, None
ref_kv_grad, kv.grad = kv.grad, None
ref_index_q_grad, index_q.grad = index_q.grad, None
ref_index_k_grad, index_k.grad = index_k.grad, None
ref_weights_grad, weights.grad = weights.grad, None
print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}")
print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}")
print(
f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
)
print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}")
print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}")
intersections = []
for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
mask = trt_np != -1
set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask])
intersection = set_ref & set_trt
intersections.append(len(intersection) / len(set_ref))
print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))
test_kernel()
# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
import torch
import torch.nn.functional as F
import functools
from typing import Callable, Any
def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
If the function is called again with the same input tensors, it will return the cached result.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
last_args: tuple | None = None
last_kwargs: dict | None = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if (
(last_args is not None and last_kwargs is not None)
and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs))
and all(a is b for a, b in zip(args, last_args, strict=False))
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper
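# Usage sketch for tensor_cache (comments only; `twice` is a hypothetical helper):
#
#     @tensor_cache
#     def twice(x: torch.Tensor) -> torch.Tensor:
#         return x * 2
#
#     t = torch.arange(4)
#     a = twice(t)          # computed
#     b = twice(t)          # same tensor objects -> cache hit, a is b
#     c = twice(t.clone())  # equal values but a different object -> recomputed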
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_cu_seqlens_from_lens(
lens: torch.LongTensor,
dtype: torch.dtype | None = torch.int32,
) -> torch.LongTensor:
return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0))
@tensor_cache
def prepare_lens_from_cu_seqlens(
cu_seqlens: torch.LongTensor,
) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()])
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1
@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
position_ids = prepare_position_ids(cu_seqlens)
return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens)
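# Worked example (comments only): for cu_seqlens = [0, 2, 5]
#     prepare_lens          -> [2, 3]
#     prepare_position_ids  -> [0, 1, 0, 1, 2]
#     prepare_sequence_ids  -> [0, 0, 1, 1, 1]
#     prepare_token_indices -> [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]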
import torch
import torch.nn.functional as F
from einops import einsum, repeat
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
pass_configs = {
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
heads: int,
dim: int,
topk: int,
sm_scale: Optional[float] = None,
block_I: int = 32,
num_stages: int = 0,
num_threads: int = 128,
):
assert num_stages == 0
assert topk == tl.math.next_power_of_2(topk)
assert topk % block_I == 0
assert heads <= 64 and heads % 8 == 0
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
dtype: str = BF16
accum_dtype: str = FP32
index_q_shape = [seq_len, heads, dim]
weights_shape = [seq_len, heads]
index_k_shape = [seq_len, dim]
shape_p = [seq_len, topk]
topk_indices_shape = [seq_len, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
if sm_scale is None:
sm_scale = dim**-0.5
@T.prim_func
def tl_indexer_bwd_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
dIndexQ: T.Tensor(index_q_shape, dtype),
dWeights: T.Tensor(weights_shape, dtype),
dIndexK: T.Tensor(index_k_shape, dtype),
AttnScore: T.Tensor(shape_p, FP32),
IndexScore: T.Tensor(shape_p, FP32),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
bos = Offsets[i_b]
num_blocks = T.ceildiv(topk, block_I)
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
weights_shared = T.alloc_shared([heads], dtype=dtype)
d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype)
d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype)
T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
T.copy(Weights[bos + i_t, :], weights_shared)
T.fill(d_index_q_frag, 0)
T.fill(d_weights_frag, 0)
for i, j in T.Parallel(heads, dim):
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
i_st = bi_i * block_I
i_ed = (bi_i + 1) * block_I
indices_shared = T.alloc_shared([block_I], dtype=INT32)
T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared)
index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0)
attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
for i in T.Parallel(block_I):
attn_score_shared[i] = AttnScore[bos + i_t, i_st + i]
index_score_shared[i] = IndexScore[bos + i_t, i_st + i]
logits = T.alloc_fragment((block_I, heads), accum_dtype)
T.gemm(
index_k_shared,
index_q_shared,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i, j in T.Parallel(block_I, heads):
logits[i, j] = T.max(logits[i, j], 0)
# dw
d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
for i, j in T.Parallel(block_I, heads):
d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)
d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype)
d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype)
for i, j in T.Parallel(block_I, heads):
d_relu = T.alloc_var(accum_dtype)
if logits[i, j] > 0:
d_relu = 1.0
else:
d_relu = 0.0
d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j]
# dq
T.copy(d_logits_qk, d_logits_qk_cast1)
T.gemm(
d_logits_qk_cast1, # [BS, HQ]
index_k_shared, # [BS, K]
d_index_q_frag, # [HQ, K]
transpose_A=True,
transpose_B=False,
clear_accum=False,
)
# dk
T.copy(d_logits_qk, d_logits_qk_cast2)
d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
T.gemm(
d_logits_qk_cast2, # [BS, HQ]
index_q_shared, # [HQ, K]
d_index_k_frag, # [BS, K]
transpose_A=False,
transpose_B=False,
clear_accum=True,
)
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
if (pos > -1) & (pos <= i_t):
T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])
for i, j in T.Parallel(heads, dim):
d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale
T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :])
T.copy(d_weights_frag, dWeights[bos + i_t, :])
return tl_indexer_bwd_kernel
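# Gradient sketch for the kernel above (comments only), matching the reference loss in
# ref_indexer_bwd below. Per query token, with selected keys j and heads h:
#     score_j = sum_h w_h * relu(sm_scale * q_h . k_j)
#     p       = softmax(score over the top-k keys)            (= IndexScore)
#     loss    = KL(attn_score || p)   =>   d loss / d score_j = p_j - attn_score_j
# The chain rule then gives the three accumulations in the loop:
#     d w_h += (p_j - a_j) * relu(sm_scale * q_h . k_j)
#     dIndexQ and dIndexK come from (p_j - a_j) * relu'(.) * w_h, with dIndexK scattered
#     back via atomic adds because several query tokens touch the same key.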
def indexer_bwd_interface(
q: torch.Tensor,
weights: torch.Tensor,
k: torch.Tensor,
attn_score: torch.Tensor,
index_score: torch.Tensor,
topk_indices: torch.Tensor,
offsets: torch.Tensor,
):
_, heads, dim, topk = *q.shape, topk_indices.shape[-1]
token_indices = prepare_token_indices(offsets)
dq = torch.zeros_like(q)
dweights = torch.zeros_like(weights)
dk = torch.zeros_like(k)
kernel = tl_indexer_bwd_impl(heads, dim, topk)
kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices)
return dq, dweights, dk
def ref_indexer_bwd(
Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
Q.requires_grad_(True)
Weights.requires_grad_(True)
K.requires_grad_(True)
softmax_scale = Q.shape[-1] ** -0.5
all_loss = []
all_log_topk_prob = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
attn_score = AttnScore[offsets[i] : offsets[i + 1]]
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale
logits = F.relu(logits)
score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
score = torch.where(mask, score, float("-inf"))
topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum")
all_loss.append(loss)
all_log_topk_prob.append(log_topk_prob)
loss = torch.stack(all_loss).sum()
loss.backward()
log_topk_prob = torch.cat(all_log_topk_prob, dim=0)
return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad
def test_kernel(
B=1,
S=2048,
H=16,
D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D)).cuda().bfloat16()
w = torch.randn((S, H)).cuda().bfloat16()
k = torch.randn((S, D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
all_attn_score = []
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
logits = torch.ones(seq_len, topk).cuda()
logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
all_attn_score.append(attn_score)
attn_score = torch.cat(all_attn_score, dim=0)
topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)
dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
if __name__ == "__main__":
test_kernel()
import math
import torch
import torch.nn.functional as F
from einops import einsum
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
pass_configs = {
tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_topk_reducesum_impl(
heads: int,
dim: int,
topk: int,
sm_scale: Optional[float] = None,
block_K: int = 32,
dtype: str = FP32,
num_stages: int = 0,
num_threads: int = 128,
):
assert topk == tl.math.next_power_of_2(topk)
assert topk % block_K == 0
assert heads <= 64 and heads % 8 == 0
assert num_stages == 0
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
index_q_shape = [seq_len, heads, dim]
weights_shape = [seq_len, heads]
index_k_shape = [seq_len, dim]
topk_indices_shape = [seq_len, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
N = 2 * topk
num_iters = int(round(math.log2(N)))
if sm_scale is None:
sm_scale = dim**-0.5
@T.macro
def bitonic_sort(
topk_index_shared: T.SharedBuffer([N], dtype=INT32),
topk_value_shared: T.SharedBuffer([N], dtype=FP32),
):
T.sync_threads()
for i1 in T.serial(num_iters):
for i2 in T.serial(i1 + 1):
for i in T.Parallel(N):
ascending = (i & (1 << (i1 + 1))) != 0
j = i ^ (1 << (i1 - i2))
if i < j and (
(ascending and topk_value_shared[i] > topk_value_shared[j])
or (not ascending and topk_value_shared[i] < topk_value_shared[j])
):
val = topk_value_shared[i]
topk_value_shared[i] = topk_value_shared[j]
topk_value_shared[j] = val
idx = topk_index_shared[i]
topk_index_shared[i] = topk_index_shared[j]
topk_index_shared[j] = idx
T.sync_threads()
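# The macro above is a standard bitonic sorting network over the 2 * topk staging buffer;
# since the direction bit of the final stage is always 0, the buffer ends up sorted in
# descending order, so the largest topk values survive in the first topk slots. A
# plain-Python sketch of the same compare-exchange pattern (illustrative only, values only):
#
#     def bitonic_sort_desc(vals):          # len(vals) must be a power of two
#         steps = int(math.log2(len(vals)))
#         for i1 in range(steps):
#             for i2 in range(i1 + 1):
#                 for i in range(len(vals)):
#                     ascending = (i & (1 << (i1 + 1))) != 0
#                     j = i ^ (1 << (i1 - i2))
#                     if i < j and ((ascending and vals[i] > vals[j])
#                                   or (not ascending and vals[i] < vals[j])):
#                         vals[i], vals[j] = vals[j], vals[i]
#         return vals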
@T.prim_func
def tl_indexer_topk_reducesum_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
ReduceSum: T.Tensor(topk_indices_shape, FP32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
bos, eos = Offsets[i_b], Offsets[i_b + 1]
num_blocks = T.ceildiv(i_t + 1, block_K)
topk_index_shared = T.alloc_shared([N], dtype=INT32)
topk_value_shared = T.alloc_shared([N], dtype=FP32)
T.fill(topk_index_shared, -1)
T.fill(topk_value_shared, float("-inf"))
T.sync_threads()
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
T.sync_threads()
weights_frag = T.alloc_shared([heads], dtype=dtype)
T.copy(Weights[bos + i_t, :], weights_frag)
T.sync_threads()
for i, j in T.Parallel(heads, dim):
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
T.sync_threads()
for bk_i in T.Pipelined(num_blocks, num_stages=num_stages):
k_st = bk_i * block_K
k_ed = T.min((bk_i + 1) * block_K, eos - bos)
index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype)
for i, j in T.Parallel(block_K, dim):
index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0)
T.sync_threads()
logits = T.alloc_fragment((block_K, heads), FP32)
T.gemm(
index_k_shared,
index_q_shared,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
T.sync_threads()
for i, j in T.Parallel(block_K, heads):
logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j]
T.sync_threads()
logits_sum = T.alloc_fragment(block_K, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
T.sync_threads()
offset = T.alloc_var(INT32)
if k_st >= topk:
offset = topk + (k_st % topk)
else:
offset = k_st
T.sync_threads()
for i in T.Parallel(block_K):
if k_st + i > i_t:
logits_sum[i] = float("-inf")
j = offset + i
topk_index_shared[j] = k_st + i
topk_value_shared[j] = logits_sum[i]
T.sync_threads()
if k_ed > topk and k_ed % topk == 0:
bitonic_sort(topk_index_shared, topk_value_shared)
bitonic_sort(topk_index_shared, topk_value_shared)
logits_max_frag = T.alloc_fragment([1], dtype=FP32)
logits_frag = T.alloc_fragment([topk], dtype=FP32)
reducesum_shared = T.alloc_shared([topk], dtype=FP32)
T.copy(topk_value_shared[:topk], logits_frag)
T.sync_threads()
T.reduce_max(logits_frag, logits_max_frag, dim=-1)
T.sync_threads()
for i in T.Parallel(topk):
logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0])
T.sync_threads()
lse_frag = T.alloc_fragment([1], dtype=FP32)
T.reduce_sum(logits_frag, lse_frag)
T.sync_threads()
for i in T.Parallel(topk):
reducesum_shared[i] = logits_frag[i] / lse_frag[0]
T.sync_threads()
# for i in T.Parallel(topk):
# reducesum_shared[i] = logits_frag[i]
# T.sync_threads()
for i in T.Parallel(topk):
if topk_index_shared[i] > i_t:
topk_index_shared[i] = -1
T.sync_threads()
T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :])
T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :])
return tl_indexer_topk_reducesum_kernel
def indexer_topk_reducesum_interface(
q: torch.Tensor,
weights: torch.Tensor,
k: torch.Tensor,
topk: int,
offsets: torch.Tensor,
dtype: str = BF16,
):
seq_len, heads, dim = q.shape
kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype)
token_indices = prepare_token_indices(offsets)
topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32)
topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32)
kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices)
return topk_indices, topk_score
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor:
all_topk_indices = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= topk
q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
softmax_scale = q.shape[-1] ** -0.5
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2")
logits = F.relu(logits)
logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale
logits = torch.where(mask, logits, float("-inf"))
topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
all_topk_indices.append(topk_indices)
all_topk_score.append(topk_score)
topk_indices = torch.cat(all_topk_indices, dim=0)
topk_score = torch.cat(all_topk_score, dim=0)
return topk_indices, topk_score
def test_kernel(
B=1,
S=2048,
H=64,
D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D)).cuda().bfloat16()
weights = torch.randn((S, H)).cuda().bfloat16()
k = torch.randn((S, D)).cuda().bfloat16()
offsets = torch.tensor([0, S], dtype=torch.int32).cuda()
ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets)
topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets)
for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
ref_np_val = ref_topk_score[j]
trt_np_val = topk_score[j]
mask = (ref_np_val > 0).cpu().numpy()
set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask])
intersection = set_ref & set_trt
print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}")
if __name__ == "__main__":
test_kernel()
# ruff: noqa
import tilelang
from tilelang import language as T
import torch
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(out_idx=[-1])
def preprocess(
H,
D,
block_ND=32,
num_stages=5,
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert dtype == T.bfloat16
assert accum_dtype == T.float32
S = T.symbolic("S")
shape = [S, H, D]
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
delta = T.alloc_fragment([block_ND], accum_dtype)
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx])
return preprocess_kernel
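# For reference, the kernel above computes the usual attention-backward "delta" term;
# a (much slower) torch equivalent would be roughly:
#
#     Delta = (O.float() * dO.float()).sum(dim=-1)    # [S, H]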
@tilelang.jit(out_idx=[-1])
def postprocess(
D,
D_tail,
kv_group=1,
block_N=64,
threads=128,
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert dtype == T.bfloat16
assert accum_dtype == T.float32
S_kv = T.symbolic("S_kv")
dkv_shape = [S_kv, kv_group, D + D_tail]
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
T.copy(
dKV[bx * block_N : (bx + 1) * block_N, by, :],
dKV_out[bx * block_N : (bx + 1) * block_N, by, :],
)
return postprocess_kernel
@tilelang.jit(
out_idx=[-2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def bwd(
H,
D,
D_tail,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
block_size=32,
num_stages=0,
threads=128,
indices_dtype=T.int32,
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert is_causal, "non-causal attention is not supported yet"
assert topk % block_size == 0, "otherwise some slots would read index 0 and load the wrong KV rows"
assert dtype == T.bfloat16
assert accum_dtype == T.float32
assert indices_dtype == T.int32
if sm_scale is None:
sm_scale = (D + D_tail) ** (-0.5)
B_plus_one = T.symbolic("B_plus_one")
S = T.symbolic("S")
H_kv = H // kv_group
q_shape = [S, H, D + D_tail]
k_shape = [S, kv_group, D + D_tail]
o_shape = [S, H, D]
indices_shape = [S, kv_group, topk]
delta_shape = [S, H]
lse_shape = [S, H]
offsets_shape = [B_plus_one]
token_indices_shape = [S, 2]
assert indices_dtype == T.int32
assert dtype == T.bfloat16
assert accum_dtype == T.float32
H = H_kv
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
BS = block_size
NS = tilelang.cdiv(topk, block_size)
split_store = 2
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
Offsets: T.Tensor(offsets_shape, indices_dtype),
TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
KV_shared = T.alloc_shared([BS, D], dtype)
KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
dO_shared = T.alloc_shared([padded_H, D], dtype)
mask = T.alloc_fragment([BS], "bool")
P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dQ_shared = T.alloc_shared([padded_H, D], dtype)
dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
max_kv_i = s_i
T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout(
{
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
}
)
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
if bi_i < BS // split_store:
acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4],
)
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4],
)
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_causal=True, return_kernel=False, delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
assert lse.is_contiguous()
S, H, dim_plus_tail_dim = q.shape
S_kv, kv_group, _ = kv.shape
assert kv.shape[-1] == dim_plus_tail_dim
assert S == S_kv
# The value dim is fixed to 512 in this example; the remaining channels are the tail dims.
D = 512
D_tail = dim_plus_tail_dim - D
topk = indices.shape[-1]
assert indices.shape == (S, kv_group, topk)
assert lse.shape == (S, H)
token_indices = prepare_token_indices(offsets)
# Get kernels
preprocess_kernel = preprocess(H, D)
bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_causal)
postprocess_kernel = postprocess(D, D_tail, kv_group)
if delta is None:
delta = preprocess_kernel(o, do)
dkv = torch.zeros_like(kv, dtype=torch.float32)
dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv)
dkv = postprocess_kernel(dkv)
return dq, dkv
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_causal=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
kv.requires_grad = True
o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_causal)
o.backward(do)
return q.grad, kv.grad
def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True):
# Prepare data
q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((S, H, DV), dtype=dtype, device="cuda")
offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, : len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets)
if check_correctness:
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum(
[
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
]
)
from tilelang.profiler import do_bench
def fn():
return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True)
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(
out_idx=[-2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=32,
num_stages=2,
threads=128,
):
assert dim == tilelang.math.next_power_of_2(dim), f"padding correctness has not been verified yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"padding correctness has not been verified yet, tail_dim={tail_dim}"
assert is_causal, "non-causal attention is not supported"
assert topk % block_I == 0, "otherwise some slots would read index 0 and load the wrong KV rows"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
else:
sm_scale = sm_scale
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
head_kv = heads // kv_group
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len, kv_group, dim + tail_dim]
o_shape = [seq_len, heads, dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30))  # large negative but finite, so m_i_prev - m_i never becomes (-inf) - (-inf) = NaN
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1)  # per-row sum of this block's exponentiated scores; accumulated into sumexp below
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, Output[bos + s_i, H0:H1, :])
T.copy(sumexp, Lse[bos + s_i, H0:H1])
return main
def sparse_mla_fwd_interface(
q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128
):
is_causal = True
assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
seq_len, heads, dim_plus_tail_dim = q.shape
seq_len_kv, kv_group, _ = kv.shape
assert seq_len == seq_len_kv
assert dim_plus_tail_dim == 576, "this interface assumes a 576-dim head (512 value dims + 64 tail dims); adjust d_v and this assert for other shapes"
dim = d_v
assert kv.shape[-1] == dim_plus_tail_dim
tail_dim = dim_plus_tail_dim - dim
_, _, topk = indices.shape
assert indices.shape == (seq_len, kv_group, topk)
token_indices = prepare_token_indices(offsets)
kernel = sparse_mla_fwd(
heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal, block_I=block_I, num_stages=num_stages, threads=threads
)
out, lse = kernel(q, kv, indices, offsets, token_indices)
return out, lse
def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_causal=True):
Q = Q.float()
KV = KV.float()
all_o = []
for i in range(offsets.shape[0] - 1):
q = Q[None, offsets[i] : offsets[i + 1]]
kv = KV[None, offsets[i] : offsets[i + 1]]
indices = Indices[None, offsets[i] : offsets[i + 1]].clone()
indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape
b, sk, g, _ = kv.shape
assert kv.shape[-1] == 576, "this reference assumes a 576-dim head (dim=512); adjust dim for other shapes"
dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
compressed_causal_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
).view(1, -1)
indices[indices > sk] = sk
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, : 1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
all_o.append(o.squeeze(0))
o = torch.cat(all_o, dim=0)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd(
B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256,
):
torch.random.manual_seed(0)
q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, : len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness:
# otherwise may cause out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=100,
warmup=250,
)
print(f"Average time: {ms:.3f} ms")
print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_fwd(
B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=1024,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256,
)
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
block_I=32,
num_stages=2,
threads=128,
):
assert dim == tilelang.math.next_power_of_2(dim), f"padding correctness has not been verified yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"padding correctness has not been verified yet, tail_dim={tail_dim}"
assert topk % block_I == 0, "otherwise some slots would read index 0 and load the wrong KV rows"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = heads // kv_group
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
@T.prim_func
def tl_sparse_mla_topk_reducesum_kernel(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore
):
with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
reducesum = T.alloc_fragment([BI], accum_dtype)
lse = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(lse, 0)
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
r_i = bx % REPLICATE_H
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
T.copy(Lse[bos + s_i, H0:H1], lse)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
T.reduce_sum(acc_s, reducesum, dim=0)
T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI])
return tl_sparse_mla_topk_reducesum_kernel
def sparse_mla_topk_reducesum_interface(
q: torch.Tensor,
kv: torch.Tensor,
topk_indices: torch.Tensor,
lse: torch.Tensor,
offsets: torch.Tensor,
dim_v: int,
):
assert kv.shape[-2] == 1
seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1]
REPLICATE_H = max(heads // 64, 1)
tail_dim = dim_plus_tail_dim - dim_v
token_indices = prepare_token_indices(offsets)
reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device)
kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk)
kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum)
reducesum = reducesum.sum(dim=-2)  # [seq_len, 1, REPLICATE_H, topk] -> [seq_len, 1, topk]
attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True)
return attn_score
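# Conceptually (comments only, ignoring the variable-length batching): for each query
# token t, the kernel re-materializes the per-head attention probabilities over the
# selected keys using the stored Lse, sums them over heads, and the interface then
# renormalizes over the top-k axis:
#
#     p_h = exp(sm_scale * q_t[h] @ kv[selected].T - lse[t, h])   # true softmax probs per head
#     attn_score_t = p.sum(over heads); attn_score_t /= attn_score_t.sum()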
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor):
# q: [batch, seq_len, heads, dim]
# k: [batch, seq_len, dim]
sm_scale = Q.shape[-1] ** -0.5
all_lse = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
q = Q[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
seq_len = q.shape[0]
mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale
logits = torch.where(mask, logits, float("-inf"))
score = F.softmax(logits, dim=-1, dtype=torch.float32)
score_sum = score.sum(dim=-2)
topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
max_logits = logits.amax(dim=-1).to(torch.float32)
lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
all_lse.append(lse)
all_topk_score.append(topk_score)
lse = torch.cat(all_lse, dim=0)
topk_score = torch.cat(all_topk_score, dim=0)
return lse, topk_score
def test_kernel(
B=1,
S=2048,
H=16,
D=512,
tail_D=64,
topk=128,
):
torch.manual_seed(42)
q = torch.randn((S, H, D + tail_D)).cuda().bfloat16()
kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)
kv = kv.unsqueeze(-2)
topk_indices = topk_indices.unsqueeze(-2)
attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}")
if __name__ == "__main__":
test_kernel()
import torch
def get_abs_err(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
return (x - y).flatten().abs().max().item()
def get_err_ratio(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
err = (x - y).flatten().square().mean().sqrt().item()
base = (x).flatten().square().mean().sqrt().item()
return err / base
def calculate_tensor_similarity(x, y, name="tensor"):
"""
Calculate similarity between two tensors using a normalized dot product metric.
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
This metric is scale-invariant and measures the cosine-like similarity normalized
by the magnitude of both tensors. It returns 1 for identical tensors and values
closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
with varying magnitudes where relative errors matter more than absolute differences.
Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes
Returns:
Similarity score in range [0, 1] where 1 means identical
"""
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print(f"\033[33mWARNING: {name} all zero\033[0m")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
"""
Assert that two tensors are similar using a global similarity metric.
Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.
Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1.0 - sim
if not (0 <= diff <= eps):
print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
if raise_assert:
assert False # noqa: B011
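# Usage sketch (comments only):
#
#     x = torch.randn(1024, 64)
#     y = x + 1e-5 * torch.randn_like(x)
#     assert_tensors_similar(x, y, eps=1e-4, name="demo")   # passes: 1 - sim is ~1e-10 here
#     # torch.testing.assert_close(x, y) could still fail if a single element exceeds its
#     # per-element tolerance, which is exactly the difference described above.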
import argparse
import itertools
import torch
import tilelang
import tilelang.language as T
def ref_program(x, y):
return x + y
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
threads = [64, 128, 256]
configs = list(itertools.product(block_M, block_N, threads))
return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for local_y, local_x in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return elem_add
def main(M=1024, N=1024, use_autotune=False):
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
if use_autotune:
kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32)
else:
# Default config
config = {"block_M": 32, "block_N": 32, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
main(args.m, args.n, args.use_autotune)
import tilelang.testing
import example_elementwise_add
def test_example_elementwise_add():
example_elementwise_add.main()
def test_example_elementwise_add_autotune():
example_elementwise_add.main(use_autotune=True)
if __name__ == "__main__":
tilelang.testing.main()
# FlashAttention
Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way.
```python
@T.prim_func
def flash_attention(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
# Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
# bx: block index in sequence dimension
# by: block index in "heads" dimension
# bz: block index in "batch" dimension
# threads=thread_num sets the number of threads per block
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
# Allocate shared memory for Q, K, V to reduce global memory accesses
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Allocate buffers on register
# acc_s: buffer to hold intermediate attention scores
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
# acc_s_cast: buffer for storing casted/adjusted scores
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
# acc_o: partial accumulation of output
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
# Buffers to track per-row maximum score and related stats
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
# Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
# Copy a block of Q from global memory to Q_shared
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
# Initialize accumulators
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
)
# Pipeline the loop to overlap copies/gemm stages
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Copy K block into shared memory
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
)
else:
T.clear(acc_s)
# Perform the Q*K^T multiplication. Here, transpose_B=True indicates that K_shared is
# used in transposed form, policy=T.GemmWarpPolicy.FullRow means each warp is responsible
# for computing an entire row of acc_s, and the resulting acc_s is retained in registers.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Copy V block into shared memory
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] *= scale
# Save old scores_max, then reset scores_max
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
# Rescale the partial output accumulation to keep exponents consistent
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# Exponentiate (scores - max) for the new block
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
# Make a cast of acc_s to fp16 for the next GEMM
T.copy(acc_s, acc_s_cast)
# Multiply the attention acc_s_cast by V and add to partial output (acc_o)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
# Update the "logsum" tracker with the newly accumulated sum
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
# Final step: divide each partial output by logsum (completing the softmax)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
# Write back the final output block from acc_o to the Output buffer
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
```
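The rescaling steps inside the loop implement the standard online-softmax recurrence. As a sketch of the math (written with the base-2 exponentials the kernel uses, where $s_{ij}$ are the already-scaled scores in `acc_s`, and $m$, $\ell$, $o$ correspond to `scores_max`, `logsum`, and `acc_o`):

$$
m_i^{\text{new}} = \max\big(m_i^{\text{old}},\ \max_j s_{ij}\big), \qquad
\alpha_i = 2^{\,m_i^{\text{old}} - m_i^{\text{new}}},
$$

$$
\ell_i^{\text{new}} = \alpha_i\,\ell_i^{\text{old}} + \sum_j 2^{\,s_{ij} - m_i^{\text{new}}}, \qquad
o_i^{\text{new}} = \alpha_i\,o_i^{\text{old}} + \sum_j 2^{\,s_{ij} - m_i^{\text{new}}}\, v_j .
$$

The final `acc_o[i, j] /= logsum[i]` turns the running numerator $o_i$ into the softmax-weighted average, matching what a full-row softmax followed by a matmul with $V$ would produce.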
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# ruff: noqa
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, PyTorch complains that the output is a view that is being modified in place
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples into one sequence. The attention_mask_in_length is used to mask out the other short samples, which enables efficient training on variable-length samples (e.g., the supervised fine-tuning task for large language models).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
which corresponds to the following 3D attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int. Each nonzero entry (e.g., 1, 2, 3) is the length of one concatenated sub-sequence in that batch row; 0 means no further sub-sequences.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected according to attention_mask_in_length.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), indices of the non-padded tokens in the flattened (batch * seqlen) axis.
batch: int, batch size of the padded output.
seqlen: int, sequence length of the padded output.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
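# --- Illustrative usage sketch (not part of the adapted file) -----------------------
# Round-trip a padded batch through unpad_input and pad_input: tokens marked 0 in the
# attention mask are dropped, and pad_input scatters the remaining tokens back to their
# original (batch, seqlen) slots, filling padded positions with zeros. Shapes here are
# arbitrary small values chosen only for the demo.
if __name__ == "__main__":
    batch, seqlen, dim = 2, 4, 8
    hidden_states = torch.randn(batch, seqlen, dim)
    attention_mask = torch.ones(batch, seqlen, dtype=torch.int32)
    attention_mask[1, -1] = 0  # last token of the second sequence is padding
    unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
    assert unpadded.shape == (attention_mask.sum().item(), dim)
    assert cu_seqlens.tolist() == [0, 4, 7] and max_seqlen == 4
    repadded = pad_input(unpadded, indices, batch, seqlen)
    # Non-padded positions survive the round trip exactly.
    assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])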
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import argparse
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
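# Folding log2(e) into the scale lets the kernel use exp2 instead of exp:
# exp(s / sqrt(d)) == exp2(s * (1 / sqrt(d)) * log2(e)).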
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
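# For causal attention, query row bx * block_M + i only attends to key positions <= its
# own, so only the first ceildiv((bx + 1) * block_M, block_N) KV blocks can contribute
# and the loop stops early.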
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd cannot be vectorized, so we reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape, lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
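# Illustrative mapping: element (b, l, h, d) = (0, 9, 0, 3) is stored at
# [0, 9 // 8, 0, 3 // 8, 3 % 2, 4 * (9 % 8) + (3 % 8) // 2] = [0, 1, 0, 0, 1, 5].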
@tilelang.jit(
out_idx=[1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_qk]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk : (bx + 1) * blk, by, :],
dQ_out[bz, bx * blk : (bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared)
return flash_bwd
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
(
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD_QK)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = mod_post(dq)
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from tilelang.contrib import nvcc
import argparse
tilelang.disable_cache()
@tilelang.jit(
out_idx=[3, 4],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, groups=1):
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
K_shared = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_N, dim_v], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim_v], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
# Warning: in causal/varlen/unaligned-seqlen scenarios, -inf can cause undefined behavior
# in the exp ops, so we use a large negative number instead.
T.fill(scores_max, T.Cast(accum_dtype, -1e30))
loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_range, num_stages=1):
T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30))
else:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim_v):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M])
return flash_fwd
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim_v]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim_v, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# bshd -> bhld so the TMA reduction instruction can be used
return T.Layout(dQ.shape, lambda b, l, h, d: [b, h, l, d])
@tilelang.jit(
out_idx=[3, 4, 5],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
dtype = T.float16
accum_dtype = T.float32
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(q_shape, dtype), # type: ignore
dK_out: T.Tensor(k_shape, dtype), # type: ignore
dV_out: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(dQ[bz, bx * blk : (bx + 1) * blk, by, :], dQ_out[bz, bx * blk : (bx + 1) * blk, by, :])
with T.Kernel(T.ceildiv(seq_len, blk), head_kv, batch, threads=128) as (bx, by, bz):
T.annotate_layout(
{
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
}
)
T.copy(dK[bz, bx * blk : (bx + 1) * blk, by, :], dK_out[bz, bx * blk : (bx + 1) * blk, by, :])
T.copy(dV[bz, bx * blk : (bx + 1) * blk, by, :], dV_out[bz, bx * blk : (bx + 1) * blk, by, :])
return flash_bwd_post
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, accum_dtype), # type: ignore
dV: T.Tensor(v_shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
dK: make_dq_layout(dK),
dV: make_dq_layout(dV),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
T.copy(dq, dq_shared)
T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared, use_tma=True)
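# Unlike the per-element atomic_add loop in the previous variant, staging the tile in
# shared memory and passing use_tma=True lets one bulk reduction update the whole tile
# (presumably lowered to a TMA reduce on Hopper-class GPUs, hence the compute-capability
# >= 9.0 assertion in __main__ below).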
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared, use_tma=True)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared, use_tma=True)
return flash_bwd
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, block_N, threads=256, num_stages=2, groups=1):
sm_scale = (1.0 / dim_qk) ** 0.5
scale = (1.0 / dim_qk) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim_qk]
k_shape = [batch, seq_len, head_kv, dim_qk]
v_shape = [batch, seq_len, head_kv, dim_v]
dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel
dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd(
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(dk_shape, dtype), # type: ignore
dV: T.Tensor(dv_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim_qk], dtype)
V_shared = T.alloc_shared([block_M, dim_v], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim_v], dtype)
dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim_v], dtype)
dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
T.annotate_layout(
{
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim_qk):
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
T.copy(dk, dk_shared)
T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
return flash_bwd
@torch.compile
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, groups=1, use_atomic=True):
BATCH, N_CTX, H, D_HEAD_QK = q.shape
D_HEAD_V = v.shape[-1]
block_M = 128
block_N = 64
mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o, lse = mod(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD_QK = q.shape
(
HEAD_KV,
D_HEAD_V,
) = v.shape[-2], v.shape[-1]
groups = H // HEAD_KV
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 32
mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
delta = mod_prep(o, do)
if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, dk, dv = mod_post(dq, dk, dv)
else:
kernel = flashattn_bwd_split_novarlen(
BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
)
shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel
shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel
dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D_QK]
# K: [B, T, HK, D_QK]
# V: [B, T, HV, D_V]
# HQ = HKV * groups
assert Q.size(2) == K.size(2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim_qk = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim_qk, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output
def main(
BATCH: int = 1,
H: int = 32,
N_CTX: int = 256,
D_HEAD_QK: int = 192,
D_HEAD_V: int = 128,
groups: int = 16,
causal: bool = False,
use_atomic: bool = True,
):
flops_per_qk = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_QK
flops_per_v = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD_V
total_flops = 3 * flops_per_qk + 2 * flops_per_v
if causal:
total_flops *= 0.5
Q = torch.empty(BATCH, N_CTX, H, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
head_kv = H // groups
K = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_QK, dtype=torch.half, device="cuda").normal_().requires_grad_()
V = torch.empty(BATCH, N_CTX, head_kv, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
dO = torch.empty(BATCH, N_CTX, H, D_HEAD_V, dtype=torch.half, device="cuda").normal_().requires_grad_()
O = attention(Q, K, V, causal, groups, use_atomic)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal, groups)
O_ref.backward(dO, retain_graph=True)
dQ_ref, Q.grad = Q.grad.clone(), None
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
def run():
O_ref.backward(dO, retain_graph=True)
def run1():
O.backward(dO, retain_graph=True)
from tilelang.profiler import do_bench
latency = do_bench(run, warmup=500)
print("torch: {:.2f} ms".format(latency))
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = do_bench(run1, warmup=500)
print("tilelang: {:.2f} ms".format(latency))
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
arch = nvcc.get_target_compute_version()
print(f"Detected GPU compute capability: {arch}")
assert float(arch) >= 9.0, "This example only supports GPU with compute capability >= 9.0"
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--h", type=int, default=32, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size")
parser.add_argument("--d_head_qk", type=int, default=192, help="Head dimension for Q/K")
parser.add_argument("--d_head_v", type=int, default=128, help="Head dimension for V")
parser.add_argument("--causal", action="store_true", help="Causal flag")
parser.add_argument("--groups", type=int, default=16, help="groups")
parser.add_argument("--use_atomic", action="store_true", default=False, help="Use atomic add for dK/dV")
parser.add_argument("--use_split", action="store_true", default=False, help="Use split for dK/dV")
args = parser.parse_args()
# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True
main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal, use_atomic)