"src/vscode:/vscode.git/clone" did not exist on "1940b3c90b43a5ce522e611de46660e40c6f013b"
Commit bf824406 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Support larger than 256 box size tma copy (#413)

* [New Feature] Add FP8 Flash Attention Implementation (#412)

* Introduce a new example script for FP8 Flash Attention in `example_mla_decode_kv_fp8.py`, showcasing the use of tilelang for efficient attention computation.
* Implement the `flashattn` function with optimized memory management and kernel execution.
* Include a reference program for comparison and performance evaluation.
* Add command-line argument parsing for batch size, number of heads, and dimensions to facilitate testing and experimentation.
* Enhance the overall structure and readability of the code.

This addition aims to improve the performance of attention mechanisms in deep learning models by leveraging FP8 precision and optimized kernel execution.

* lint fix

* optimize quick start

* lint fix
parent f221ff39
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
    # 1/sqrt(dim + pe_dim), with log2(e) folded in so the kernel can use exp2.
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    q_dtype = "e4m3_float8"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.prim_func
    def main_no_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            qKV_shared = T.alloc_shared([block_N, dim], q_dtype)  # raw FP8 KV tile
            KV_shared = T.alloc_shared([block_N, dim], dtype)  # fp16 copy fed to the gemms
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
            })

            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.NoSetMaxNReg()
            loop_range = T.ceildiv(seqlen_kv, block_N)
            for k in T.Pipelined(loop_range, num_stages=2):
                # Load the FP8 KV tile, then widen it to fp16 for the tensor-core gemms.
                T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
                T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
                T.copy(qKV_shared, KV_shared)
                T.clear(acc_s)
                T.gemm(
                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
                # Online-softmax update: rescale previous partial results by the
                # change in the running row maximum, using exp2 (scale carries log2(e)).
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            # Final normalization by the accumulated log-sum denominator.
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])

    return main_no_split
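A note on the scale factor above: the kernel folds log2(e) into `scale` once so the softmax can use the cheaper hardware exp2 instead of exp, relying on exp(x) = 2^(x * log2(e)). A standalone PyTorch sanity check of that identity (illustrative only, not part of the commit; the constants mirror the example's default dim=512, pe_dim=64):

import torch

x = torch.randn(8, 64, dtype=torch.float32)
scale = (1.0 / (512 + 64))**0.5           # 1/sqrt(dim + pe_dim)
ref = torch.softmax(x * scale, dim=-1)    # natural-exp softmax

s2 = scale * 1.44269504                   # fold in log2(e), as the kernel does
m = (x * s2).max(dim=-1, keepdim=True).values
p = torch.exp2(x * s2 - m)                # numerically stable exp2 form
out = p / p.sum(dim=-1, keepdim=True)
assert torch.allclose(ref, out, atol=1e-6)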
def ref_program(q, q_pe, kv, k_pe):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch, num_head_groups, kv_head_num, dim]
    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch, num_head_groups, kv_head_num, pe_dim]
    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch, kv_head_num, seqlen_kv, dim]
    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch, kv_head_num, seqlen_kv, pe_dim]
    query = torch.concat([q, q_pe], dim=-1)  # [batch, num_head_groups, kv_head_num, dim + pe_dim]
    key = torch.concat([kv, k_pe], dim=-1)  # [batch, kv_head_num, seqlen_kv, dim + pe_dim]
    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch, num_head_groups, kv_head_num, seqlen_kv]
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch, num_head_groups, kv_head_num, seqlen_kv]
    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch, num_head_groups, kv_head_num, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch, heads, dim]
    return out
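The reference relies on einops grouped-query reshaping: the query heads are folded into num_head_groups groups that all attend against the single KV head. A tiny shape walkthrough (illustrative only, small shapes chosen for readability):

import torch
from einops import rearrange

q = torch.randn(2, 8, 4)                        # [batch=2, heads=8, dim=4]
g = 8 // 1                                      # num_head_groups with kv_head_num=1
qg = rearrange(q, 'b (h g) d -> b g h d', g=g)
print(qg.shape)                                 # torch.Size([2, 8, 1, 4]) = [b, g, kv_heads, d]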
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H)
print(program)
kernel = tilelang.compile(program, out_idx=-1)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
@@ -55,8 +55,15 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     return main
 
+M = 1024  # M = T.symbolic("m") if you want to use dynamic shape
+N = 1024
+K = 1024
+block_M = 128
+block_N = 128
+block_K = 32
+
 # 1. Define the kernel (matmul) and compile/lower it into an executable module
-func = matmul(1024, 1024, 1024, 128, 128, 32)
+func = matmul(M, N, K, block_M, block_N, block_K)
 
 # 2. Compile the kernel into a torch function
 # out_idx specifies the index of the output buffer in the argument list
@@ -69,8 +76,8 @@ jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backen
 import torch
 
 # Create random input tensors on the GPU
-a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
-b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+b = torch.randn(K, N, device="cuda", dtype=torch.float16)
 
 # Run the kernel through the Profiler
 c = jit_kernel(a, b)
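The new comment on M points at tilelang's dynamic-shape path. A minimal sketch of what that swap would look like (assuming T.symbolic behaves as the quickstart comment suggests; everything else stays unchanged):

import tilelang.language as T

M = T.symbolic("m")  # runtime-dynamic rows instead of a fixed 1024
func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
# the compiled kernel now accepts inputs a of shape [m, K] for any runtime m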
@@ -196,6 +196,13 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
   } else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B)) {
     instruction_dim = 128 / src->dtype.bytes();
   }
+  if (instruction_dim > 256) {
+    // smem_box dims must lie in [0, 256];
+    // if it is 512, we need to split the copy into two parts
+    ICHECK((*inner_box_dim) % 256 == 0)
+        << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256";
+    instruction_dim = 256;
+  }
   ICHECK((*inner_box_dim) % instruction_dim == 0);
   desc.smem_box.Set(0, PrimExpr(instruction_dim));
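For intuition on this guard: CUDA's cuTensorMapEncodeTiled caps each smem box dimension at 256 elements, while an FP8 row of dim=512 (as in the new example, one byte per e4m3 element) yields an inner box dim of 512. A back-of-envelope sketch of what the clamp implies (plain Python, not the C++ above; the surrounding lowering presumably issues the remaining inner_box_dim / instruction_dim copies):

inner_box_dim = 512                        # e.g. dim=512 with 1-byte e4m3 elements
instruction_dim = min(inner_box_dim, 256)  # clamp to the TMA box-dim limit
assert inner_box_dim % instruction_dim == 0
num_copies = inner_box_dim // instruction_dim
print(instruction_dim, num_copies)         # 256 2 -> the copy is issued in two parts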
@@ -48,9 +48,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
-    mod = tilelang.transform.MergeIfStmt()(mod)
-    # TODO(lei): may need a pass to fuse the if-then-else in the
-    # pipeline loop when we meet dynamic branch.
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
@@ -64,22 +64,10 @@ TMA_DESC_INIT_FUNC = """
 \tif ({0}_result != CUDA_SUCCESS) {{
 \t\tstd::stringstream ss;
-\t\tss << "TMA Desc Addr: " << &{0}
-\t\t\t<< "\\nformat " << {0}_type
-\t\t\t<< "\\ndim " << {0}_tensorRank
-\t\t\t<< "\\ngmem_address " << {0}_globalAddress
-\t\t\t<< "\\nglobalDim " << {0}_globalDim
-\t\t\t<< "\\nglobalStrides " << {0}_globalStride + 1
-\t\t\t<< "\\nboxDim " << {0}_boxDim
-\t\t\t<< "\\nelementStrides " << {0}_elementStrides
-\t\t\t<< "\\ninterleave " << {0}_interleave
-\t\t\t<< "\\nswizzle " << {0}_swizzle
-\t\t\t<< "\\nl2Promotion " << {0}_l2Promotion
-\t\t\t<< "\\noobFill " << {0}_oobFill
-\t\t\t<< "\\nError: Failed to initialize the TMA descriptor {0}";
+\t\tss << "Error: Failed to initialize the TMA descriptor {0}";
 \t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
 \t\treturn -1;
-}}
+\t}}
 """