Commit d110d087 authored by Lei Wang, committed by LeiWang1999

[Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to call kernels directly instead of going through cached versions, improving clarity and performance.
- Added `@tilelang.jit` decorators with specified output indices to enhance kernel compilation (see the sketch after this list).
- Removed unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py`, keeping it in place for potential future use.
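
As a rough illustration of the pattern described above, the following is a minimal sketch rather than code from this commit: a trivial block-copy kernel built from the same constructs that appear in the diff below (`T.Kernel`, `T.alloc_shared`, `T.copy`), decorated with `@tilelang.jit(out_idx=[1])` so the second argument is treated as the output, and invoked directly without a `cached` wrapper. The shapes, block sizes, and the kernel itself are illustrative assumptions.

import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[1])  # argument 1 (B) is allocated and returned as the output
def copy_kernel(M, N, block_M=128, block_N=128, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        # One block per (block_M, block_N) tile; stage each tile through shared memory.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, B[by * block_M, bx * block_N])

    return main

# Direct call: compiling for concrete shapes yields a runnable kernel object.
# kernel = copy_kernel(1024, 1024)
# B = kernel(A)  # A is a device tensor; B is produced via out_idx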

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to use a dictionary-based approach for parameter configuration (sketched after this list), improving readability and maintainability.
- Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning.
- Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized.
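
As a condensed sketch of that dictionary-based pattern (with a deliberately reduced parameter set; the real search spaces appear in the diff below):

import itertools

def get_configs():
    # Each key is a tunable parameter; each value lists the candidates to sweep.
    iter_params = dict(
        block_M=[64, 128],
        block_N=[64, 128],
        num_stages=[0, 2],
    )
    # One dict per point in the Cartesian product of all candidate values.
    return [dict(zip(iter_params, values))
            for values in itertools.product(*iter_params.values())]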

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in benchmark scripts to accept a variable argument list, improving flexibility in configuration management.
- Enhanced the `matmul` and `flashattn` functions to use the updated configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators to the relevant functions, ensuring consistent autotuning behavior across benchmarks (see the sketch after this list).
- Cleaned up redundant code and improved overall readability in the affected files.
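
A hedged sketch of how these pieces fit together; the `autotune` import path is an assumption (the import block is elided in this diff), and the element-wise kernel is a stand-in rather than one of the benchmarked kernels:

import itertools
import tilelang.language as T
from tilelang import jit
from tilelang.autotuner import autotune  # assumed import path, not shown in this diff

def get_configs(args, kwargs):
    # The tuner forwards the decorated function's arguments, so the search
    # space can depend on the problem size (here M, N, K) if needed.
    M, N, K = args[:3]
    iter_params = dict(block_M=[64, 128], block_N=[64, 128], num_stages=[0, 2])
    return [dict(zip(iter_params, values))
            for values in itertools.product(*iter_params.values())]

@autotune(configs=get_configs, warmup=3, rep=20)
@jit(out_idx=[2])  # argument 2 (C) is the output tensor
def add_kernel(M, N, K, block_M=None, block_N=None, num_stages=None):
    dtype = "float16"

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
             C: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = (
                    A[by * block_M + i, bx * block_N + j] +
                    B[by * block_M + i, bx * block_N + j])

    return main

# kernel = add_kernel(1024, 1024, 1024)   # runs the sweep and returns the best kernel
# print(kernel.config, kernel.latency)    # attributes read by the updated examples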

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check whether kernel parameters are set before processing tunable arguments, improving robustness in configuration handling.
- Enhanced the logic for skipping compilation when tunable parameters are already provided, ensuring efficient use of resources (illustrated after this list).
- Updated comments for clarity and maintainability.
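
Purely as a hypothetical illustration of that skip logic (this is not the actual AutoTuner code, whose internals are not shown in this diff):

def select_configs(tunable_names, call_kwargs, candidate_configs):
    """Hypothetical helper: if the caller has already fixed every tunable
    parameter, there is nothing left to search, so return only that single
    configuration and skip the compile-and-benchmark sweep."""
    missing = [name for name in tunable_names if call_kwargs.get(name) is None]
    if not missing:
        return [{name: call_kwargs[name] for name in tunable_names}]
    return candidate_configs

# select_configs(["block_M", "block_N"], {"block_M": 128, "block_N": 64}, configs)
# -> [{"block_M": 128, "block_N": 64}]   (single config, no tuning sweep)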

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call for potential execution.
parent 78056597
@@ -29,7 +29,7 @@ def ref_program(A, B):
return A @ B.T
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
@@ -44,6 +44,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K, with_roller = args[:4]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
@@ -86,40 +88,40 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[32, 64],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs
def matmul(M, N, K, with_roller):
@autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
@@ -146,118 +148,66 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
return kernel()
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
if __name__ == "__main__":
......
@@ -162,7 +162,7 @@ def ref_program(A, B):
return A @ B.T
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
@@ -177,6 +177,9 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K = args[:3]
with_roller = args[6]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
@@ -218,62 +221,49 @@ def get_configs(M, N, K, with_roller=False):
print(config)
else:
block_rows_warps = [1, 2, 4]
block_col_warps = [1, 2, 4]
warp_row_tiles = [16, 32, 64, 128]
warp_col_tiles = [16, 32, 64, 128]
chunk = [32, 64, 128, 256]
stage = [0, 2]
enable_rasteration = [True, False]
_configs = list(
itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
chunk, stage, enable_rasteration))
configs = [{
"block_row_warps": c[0],
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
warp_row_tiles=[16, 32, 64, 128],
warp_col_tiles=[16, 32, 64, 128],
chunk=[32, 64, 128, 256],
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs
def matmul(M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False):
@autotune(
configs=get_configs,
warmup=3,
rep=5,
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False,
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
warp_col_tiles=None,
chunk=None,
stage=None,
enable_rasteration=None,
):
"""Create an autotuned tensor core matrix multiplication kernel."""
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
)
@tl.jit(out_idx=[2],)
def kernel(
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
warp_col_tiles=None,
chunk=None,
stage=None,
enable_rasteration=None,
):
def kernel():
return tl_matmul(
M,
N,
......
@@ -30,7 +30,7 @@ def ref_program(A, B):
return A.float() @ B.T.float()
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
@@ -45,6 +45,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K, with_roller = args[:4]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
@@ -87,40 +89,41 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul(M, N, K, with_roller):
@autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
@@ -147,117 +150,66 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "e4m3_float8"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "e4m3_float8"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
return kernel()
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A[by * block_M, k * block_K], A_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
if __name__ == "__main__":
......
@@ -337,10 +337,6 @@ def get_best_config(N, K):
@autotune(
configs=get_configs(),
keys=[
"BLOCK_N",
"reduce_threads",
],
warmup=3,
rep=20,
)
......
@@ -223,6 +223,7 @@ def convolution(N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
@@ -291,14 +292,14 @@ def main(n: int = 128,
with_roller: bool = True):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
use_autotune = True
if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
print(result.config)
kernel = result.kernel
else:
config = get_heuristic_config()
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2])
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
......
@@ -57,7 +57,18 @@ def get_configs(user_config=None):
return valid_configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
@@ -65,142 +76,119 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return kernel
return main
def ref_program(Q, K, V, is_causal, groups=1):
@@ -242,8 +230,16 @@ def main(batch: int = 1,
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
block_M=64, block_N=64, num_stages=2, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
@@ -255,10 +251,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
@@ -9,22 +9,33 @@ from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
iter_params = dict(
block_M=[128],
block_N=[128],
num_stages=[2],
threads=[256],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=10,
rep=10,
)
@tilelang.jit(out_idx=[3])
def flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128,
):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
@@ -32,147 +43,124 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return kernel
return main
def ref_program(Q, K, V, is_causal, groups=1):
@@ -216,8 +204,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
@@ -229,10 +225,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
@@ -6,166 +6,147 @@ import tilelang.language as T
import itertools
import argparse
from functools import partial
from tilelang import jit
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return kernel
return main
def ref_program(Q, K, V, is_causal):
@@ -199,8 +180,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
block_M=64, block_N=64, num_stages=1, threads=128)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
@@ -213,10 +202,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
@@ -6,171 +6,152 @@ import tilelang.language as T
import itertools
import argparse
from functools import partial
from tilelang import jit
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
past_len = seq_kv - seq_q
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
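# Clarifying note: with scale = log2(e) / sqrt(dim), exp2(x * scale - m * scale)
# equals exp((x - m) / sqrt(dim)), so this block computes the softmax numerator in
# base 2. When a new running maximum is found, scores_scale = exp2((m_prev - m_new) * scale)
# rescales the previous partial results, which is why logsum is updated as
# logsum * scores_scale + scores_sum and acc_o is rescaled before the next MMA1.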
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
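# Worked example (assuming seq_q == seq_kv): with block_M = block_N = 128 and
# bx = 2, T.ceildiv((bx + 1) * block_M, block_N) = ceildiv(384, 128) = 3, so the
# causal branch visits only the first 3 KV blocks for this query block; the outer
# T.min keeps the bound at T.ceildiv(seq_kv, block_N) when that is smaller.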
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return kernel
return main
def ref_program(Q, K, V, is_causal):
......@@ -204,8 +185,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
......@@ -218,10 +207,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -9,159 +9,139 @@ from functools import partial
def get_configs():
block_M = [64]
block_N = [64]
num_stages = [1]
threads = [128]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return kernel
return main
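# Layout note: this variant stores Q, K, V and Output as [batch, seq_len, heads, dim],
# so global-memory slices index the sequence dimension second and the head third,
# e.g. Q[bz, bx * block_M:(bx + 1) * block_M, by, :], whereas the previous example
# used [batch, heads, seq, dim] and indexed Q[bz, by, bx * block_M:(bx + 1) * block_M, :].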
def ref_program(Q, K, V, is_causal):
......@@ -193,8 +173,15 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=1, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......
......@@ -9,164 +9,144 @@ from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(
loop_range,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return kernel
return main
def ref_program(Q, K, V, is_causal):
......@@ -198,8 +178,15 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......@@ -211,10 +198,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -54,7 +54,10 @@ def get_pass_configs():
return {}
def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages,
threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim]
......@@ -64,260 +67,224 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
accum_dtype = "float"
kv_group_num = heads // groups
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def kernel_func(block_N, block_H, num_split, num_stages, threads):
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split)
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split)
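# Worked example (illustrative numbers): with heads = 32, groups = 8 and
# block_H = 64, kv_group_num = 32 // 8 = 4 and valid_block_H = min(64, 4) = 4, so
# each program block covers the 4 query heads that share one KV head; the grid
# then uses heads // valid_block_H = 8 blocks on that axis and
# cur_kv_head = hid // (kv_group_num // valid_block_H) = hid.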
@T.macro
def flash_attn(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
T.fill(K_shared, 0)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head, :], K_shared)
T.copy(
mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
(mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
mask_local = T.alloc_fragment([block_N], "uint8")
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
T.fill(K_shared, 0)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], K_shared)
T.copy(
mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i,
j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
sid, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local:
T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local:
T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
for i in T.Parallel(block_H):
if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
sid, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn(Q, K, V, mask, Output)
if num_split > 1:
return flashattn_gqa_decode_split
else:
return flashattn_gqa_decode_no_split
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(
out_idx=[6],
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
max_mismatched_ratio=0.05)
def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel()
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
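# Clarifying note: combine merges the per-split partial outputs with a
# numerically stable log-sum-exp in base 2:
#     lse_total = log2(sum_k 2^(lse_k - lse_max)) + lse_max
#     Output    = sum_k 2^(lse_k - lse_total) * Output_partial_k
# Subtracting lse_max before exponentiating keeps the 2^x terms in range, and the
# weights 2^(lse_k - lse_total) sum to 1 across the num_split partial results.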
@T.prim_func
def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn(Q, K, V, mask, Output)
if num_split > 1:
return flashattn_gqa_decode_split
else:
def kernel(block_N, block_H, num_split, num_stages, threads):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel
return flashattn_gqa_decode_no_split
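# Clarifying note: the builder returns one of two prim_funcs that share the same
# seven-argument signature: flashattn_gqa_decode_split (flash_attn_split followed
# by combine) when num_split > 1, and the single-pass flashattn_gqa_decode_no_split
# otherwise, so out_idx=[6] points at the final Output argument in either case.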
def ref_program(query, key, value, mask, glse, Output_partial):
......@@ -485,7 +452,7 @@ def main(batch: int = 1,
if (not tune):
config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
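# Clarifying note: get_heuristic_config() is expected to return the tunable
# parameters of flashattn as a dict, e.g. (hypothetical values)
#   {"block_N": 128, "block_H": 64, "num_split": 4, "num_stages": 2, "threads": 128}
# which **config unpacks into the keyword arguments above.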
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
......@@ -513,10 +480,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -219,17 +219,10 @@ def splitk_gemv_vectorized_tvm(
def get_best_config(N, K):
def get_configs():
BLOCK_N = [2, 4, 8, 32, 64, 128]
reduce_threads = [4, 8, 32]
_configs = list(itertools.product(
BLOCK_N,
reduce_threads,
))
configs = [{
"BLOCK_N": c[0],
"reduce_threads": c[1],
} for c in _configs]
return configs
iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
return [
dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
]
@autotune(
configs=get_configs(),
......
......@@ -61,158 +61,139 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def get_configs():
block_M = [64, 128, 256]
block_N = [32, 64]
block_K = [64, 128, 256]
block_Dstate = [128]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'block_Dstate': c[3],
'num_stages': c[4],
'threads': c[0] * 2
} for c in _configs]
return configs
iter_params = dict(
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
def chunk_scan_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):
@T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), C: T.Tensor(
(batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
(nheads), dtype), Output: T.Tensor(
(batch, seqlen, nheads, headdim), dtype)):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
cb_local = T.alloc_fragment((block_M, block_K), dtype)
dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
dt_local = T.alloc_fragment((block_K), accum_dtype)
x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
scale_m_local = T.alloc_fragment((block_M), accum_dtype)
C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
D_local = T.alloc_fragment((1), accum_dtype)
x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)
batch_idx = by % batch
chunk_idx = by // batch
# m: chunk_size
# n: headdim
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
@T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype),
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
(nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
cb_local = T.alloc_fragment((block_M, block_K), dtype)
dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
dt_local = T.alloc_fragment((block_K), accum_dtype)
x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
scale_m_local = T.alloc_fragment((block_M), accum_dtype)
C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
D_local = T.alloc_fragment((1), accum_dtype)
x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)
batch_idx = by % batch
chunk_idx = by // batch
# m: chunk_size
# n: headdim
m_idx = bx // T.ceildiv(headdim, block_N)
n_idx = bx % T.ceildiv(headdim, block_N)
T.annotate_layout({
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
})
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
dA_cs_m_shared)
T.copy(dA_cs_m_shared, dA_cs_m_local)
T.clear(acc_o)
for i in T.Parallel(block_M):
scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i,
j] = cb_local[i,
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
T.copy(
prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
0:block_Dstate], prev_state_shared)
T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] *= scale_m_local[i]
loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
cb_shared)
T.copy(cb_shared, cb_local)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
dA_cs_k_local[j] * p)
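# Clarifying note: p = 1.44269504 = log2(e), so
# T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) equals
# exp(dA_cs_m[i] - dA_cs_k[j]), the cumulative decay between positions j and i of
# the chunk; the if_then_else below then zeroes entries where j lies ahead of i.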
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] *= dt_local[j]
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
cb_local[i, j], 0)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size +
m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz,
n_idx * block_N:(n_idx + 1) * block_N])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
def kernel(block_M=None,
block_N=None,
block_K=None,
block_Dstate=None,
num_stages=None,
threads=None):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel()
else:
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
T.gemm(cb_local, x_shared, acc_o)
D_local[0] = D[bz]
T.copy(
x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
x_residual_shared)
T.copy(x_residual_shared, x_residual_local)
for i, j in T.Parallel(block_M, block_N):
acc_o[i, j] += x_residual_local[i, j] * D_local[0]
def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
return kernel
return main
if __name__ == "__main__":
......@@ -231,8 +212,19 @@ if __name__ == "__main__":
if (not args.tune):
kernel = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
batch,
seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -243,11 +235,10 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......@@ -46,110 +46,94 @@ def ref_program(B, x, dt, dA_cumsum):
def get_configs():
block_M = [64, 128]
block_N = [32, 64, 128]
block_K = [32, 64]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[0] * 2
} for c in _configs]
return configs
iter_params = dict(
block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
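For reference, a minimal standalone sketch (assuming only the parameter lists shown above) of what this dictionary-based helper expands to:

```python
import itertools

# Cartesian product of the per-parameter value lists, one dict per candidate
# configuration, mirroring the get_configs() pattern above.
iter_params = dict(
    block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

print(len(configs))  # 2 * 3 * 2 * 5 = 60 candidate configurations
print(configs[0])    # {'block_M': 64, 'block_N': 32, 'block_K': 32, 'num_stages': 1}
```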
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
def chunk_state_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, num_stages, threads):
@T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype)):
with T.Kernel(
nheads,
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
dt_shared = T.alloc_shared((block_K), dtype)
dA_cumsum_shared = T.alloc_shared((block_K), dtype)
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
scale = T.alloc_fragment((block_K), accum_dtype)
dA_cs_last = T.alloc_fragment((1), accum_dtype)
dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype)
dt_local = T.alloc_fragment((block_K), accum_dtype)
loop_range = T.ceildiv(chunk_size, block_K)
batch_idx = by % batch
chunk_idx = by // batch
m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N)
T.annotate_layout({
x_shared: tilelang.layout.make_swizzled_layout(x_shared),
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
})
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K):
scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i]
T.copy(x_shared, x_local)
for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j]
T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz // (nheads // ngroups),
n_idx * block_N:(n_idx + 1) * block_N], B_shared)
T.gemm(xt_local, B_shared, acc_o)
T.copy(acc_o, acc_o_shared)
@T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype)):
with T.Kernel(
nheads,
T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
batch * nchunks,
threads=threads) as (bz, bx, by):
x_shared = T.alloc_shared((block_K, block_M), dtype)
x_local = T.alloc_fragment((block_K, block_M), dtype)
xt_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
dt_shared = T.alloc_shared((block_K), dtype)
dA_cumsum_shared = T.alloc_shared((block_K), dtype)
acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
scale = T.alloc_fragment((block_K), accum_dtype)
dA_cs_last = T.alloc_fragment((1), accum_dtype)
dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype)
dt_local = T.alloc_fragment((block_K), accum_dtype)
loop_range = T.ceildiv(chunk_size, block_K)
batch_idx = by % batch
chunk_idx = by // batch
m_idx = bx // T.ceildiv(dstate, block_N)
n_idx = bx % T.ceildiv(dstate, block_N)
T.annotate_layout({
x_shared: tilelang.layout.make_swizzled_layout(x_shared),
acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
})
dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
T.clear(acc_o)
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
n_idx * block_N:(n_idx + 1) * block_N])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
T.copy(
x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
dA_cumsum_shared)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dA_cumsum_shared, dA_cumsum_local)
T.copy(dt_shared, dt_local)
for i in T.Parallel(block_K):
scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i]
T.copy(x_shared, x_local)
for i, j in T.Parallel(block_M, block_K):
xt_local[i, j] = x_local[j, i] * scale[j]
T.copy(
B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
(k + 1) * block_K, bz // (nheads // ngroups),
n_idx * block_N:(n_idx + 1) * block_N], B_shared)
T.gemm(xt_local, B_shared, acc_o)
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
n_idx * block_N:(n_idx + 1) * block_N])
return kernel
return main
if __name__ == "__main__":
......@@ -168,8 +152,18 @@ if __name__ == "__main__":
if (not args.tune):
kernel = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
batch,
seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=128,
block_K=64,
num_stages=4,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -180,8 +174,7 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
......
import torch
import torch.nn.functional as F
import tilelang
from tilelang import cached
import tilelang.language as T
import tilelang.testing
......@@ -9,6 +8,7 @@ import tilelang.testing
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[3, 4],)
def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
......@@ -78,6 +78,7 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
return flash_fwd
@tilelang.jit(out_idx=[2],)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
......@@ -113,6 +114,7 @@ def make_dq_layout(dQ):
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(out_idx=[1],)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
......@@ -134,11 +136,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post
@tilelang.jit(
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
@tilelang.jit(out_idx=[7, 8])
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
......@@ -161,10 +159,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
......@@ -182,7 +176,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
......@@ -237,8 +230,8 @@ class _attention(torch.autograd.Function):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4])
o, lse = mod(q, k, v)
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = kernel(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
......@@ -256,13 +249,13 @@ class _attention(torch.autograd.Function):
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 128 if D_HEAD <= 64 else 32
mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2])
mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1])
delta = mod_prep(o, do)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
delta = kernel_prep(o, do)
dq = torch.zeros_like(q, dtype=torch.float32)
dk, dv = mod(q, k, v, do, lse, delta, dq)
dq = mod_post(dq)
dk, dv = kernel(q, k, v, do, lse, delta, dq)
dq = kernel_post(dq)
return dq, dk, dv, None
......@@ -307,8 +300,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
def test_mha_bwd():
assert_mha_equal(8, 32, 128, 64, False)
assert_mha_equal(8, 32, 128, 64, True)
assert_mha_equal(8, 32, 256, 64, False)
assert_mha_equal(8, 32, 256, 64, True)
if __name__ == "__main__":
......
......@@ -263,6 +263,9 @@ class AutoTuner:
sig = inspect.signature(self.fn)
parameters = sig.parameters
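# `configs` may be a callable; resolve it here with the (args, kwargs) captured
# from the kernel call, so the search space can depend on the runtime problem shape.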
if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters)
with self._lock:
......@@ -392,6 +395,31 @@ class AutoTuner:
raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs)
if len(config_args) == 0:
raise ValueError("No configurations to tune, please check your `@autotune` decorator")
# Check whether the tunable arguments have already been provided.
# Take the first config to learn the tunable parameter names.
top_config, *rest = config_args
if self._kernel_parameters is not None:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]
# Check whether any tunable argument was already supplied at call time by comparing config keys with key_kwargs_tuple
if any(key in top_config for key, _ in key_kwargs_tuple):
logger.warning(
f"Tunable parameters {tunable_arguments} already provided at call time. Skipping the tuning search and compiling directly with JIT"
)
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
num_workers = max(1, int(get_available_cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = []
......@@ -502,7 +530,7 @@ class _AutoTunerImplementation:
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Any = None
configs: Union[Dict, Callable] = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
......@@ -514,7 +542,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False
def __init__(self,
configs: Any,
configs: Union[Dict, Callable],
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
......@@ -581,7 +609,6 @@ class _AutoTunerImplementation:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
configs = self.configs
@functools.wraps(fn)
def wrapper(*args, **kwargs):
......@@ -598,7 +625,7 @@ class _AutoTunerImplementation:
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=configs).set_profile_args(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
......@@ -634,7 +661,7 @@ class _AutoTunerImplementation:
def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only
configs: Any,
configs: Union[Dict, Callable],
# profile arguments
warmup: int = 25,
rep: int = 100,
......@@ -655,6 +682,16 @@ def autotune( # This is the new public interface
This decorator is applied with keyword arguments (e.g., `@autotune(configs=..., warmup=..., rep=...)`)
to configure the auto-tuning process for the decorated kernel.
Tips:
- If you want to skip the auto-tuning process, you can override the tunable parameters by passing them explicitly when calling the kernel function, as in the example below.
```python
if enable_autotune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
```
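- `configs` may also be a callable; the tuner resolves it with the positional and keyword arguments of the kernel call, so the search space can depend on the problem shape. A minimal illustrative sketch follows (the `get_configs` helper, the `my_kernel` name, the import path, and the `out_idx` value are assumptions for illustration, not part of this change):
```python
import tilelang
from tilelang.autotuner import autotune  # import path assumed for this sketch


def get_configs(args, kwargs):
    # args/kwargs are the arguments of the tuned kernel call; here we assume
    # the first positional argument is a problem size M (illustrative only).
    M = args[0]
    block_M = [64, 128] if M >= 128 else [32, 64]
    return [{"block_M": bm, "num_stages": s} for bm in block_M for s in (1, 2, 3)]


@autotune(configs=get_configs, warmup=10, rep=10)
@tilelang.jit(out_idx=[2])  # out_idx value is illustrative
def my_kernel(M, N, K, block_M=None, num_stages=None):
    ...
```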
Parameters
----------
......@@ -662,6 +699,13 @@ def autotune( # This is the new public interface
If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
If using `@tilelang.jit` directly on a function, this argument is implicitly
the function to be decorated (and `out_idx` will be `None`).
configs : Dict or Callable
The configuration space to explore: typically a list of keyword-argument dicts (e.g., from `get_configs()`), or a callable that builds such a list from the kernel call's arguments.
warmup : int, optional
Number of warmup iterations before timing.
rep : int, optional
Number of repetitions for timing measurements.
timeout : int, optional
target : Union[str, Target], optional
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
......
......@@ -142,12 +142,12 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float
config: dict
ref_latency: float
libcode: str
func: Callable
kernel: Callable
latency: Optional[float] = None
config: Optional[dict] = None
ref_latency: Optional[float] = None
libcode: Optional[str] = None
func: Optional[Callable] = None
kernel: Optional[Callable] = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel):
"""
......