"docs/en_US/vscode:/vscode.git/clone" did not exist on "4bbffd170b9c5fd660ebea4c837867df90aaaabf"
Commit d110d087 authored by Lei Wang, committed by LeiWang1999

[Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to use direct kernel calls instead of cached versions, improving clarity and performance (see the before/after excerpt after this list).
- Added `@tilelang.jit` decorators with explicit output indices so each kernel factory returns a compiled kernel.
- Removed the now-unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py`; it is re-enabled later in this series.
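
For reference, the before/after pattern in the test file (excerpted from the `test_tilelang_kernel_mha_bwd.py` diff further down; variable names are from that test):

```python
# Before: compile through the cached() helper, passing the output indices by hand.
mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4])
o, lse = mod(q, k, v)

# After: flashattn_fwd is decorated with @tilelang.jit(out_idx=[3, 4]),
# so calling it already returns a compiled kernel.
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = kernel(q, k, v)
```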

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to use a dictionary-based approach for parameter configuration, improving readability and maintainability (see the sketch after this list).
- Updated the `flashattn` and `chunk_scan_fwd` functions to accept configuration parameters directly, making kernel tuning more flexible.
- Removed redundant code and streamlined configuration generation across the affected files so configurations are defined and used consistently.
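
A minimal standalone sketch of that pattern (the tilelang-specific entries such as `policy` and `enable_rasteration` are left out so the snippet runs on its own; the full version appears in the diff below):

```python
import itertools


def get_configs(args, kwargs):
    # The autotuner passes the kernel call's positional and keyword arguments,
    # so the search space can depend on the problem shape if needed.
    M, N, K, with_roller = args[:4]
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[32, 64],
        num_stages=[0, 1, 2, 3],
        thread_num=[128, 256],
    )
    # One config dict per point in the Cartesian product of the value lists.
    return [
        dict(zip(iter_params, values))
        for values in itertools.product(*iter_params.values())
    ]


# get_configs((1024, 1024, 1024, False), {}) -> 144 candidate configurations
```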

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in the benchmark scripts to accept the kernel's argument list, so the search space can depend on the problem shape (see the condensed decorator layout after this list).
- Updated the `matmul` and `flashattn` functions to use the new configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators directly to the relevant entry functions, ensuring consistent autotuning behavior across benchmarks.
- Cleaned up redundant code and improved overall readability in the affected files.
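
Condensed from the `benchmark_matmul` diff below, the decorator layout now looks roughly like this (the `autotune` import path is an assumption based on the autotuner module later in this diff, and the kernel body is elided):

```python
from tilelang import jit
from tilelang.autotuner import autotune  # assumed import path


@autotune(
    configs=get_configs,  # a callable; resolved later with the kernel call's (args, kwargs)
    warmup=3,
    rep=20,
)
@jit(out_idx=[2])
def matmul(M, N, K, with_roller,
           block_M=None, block_N=None, block_K=None, num_stages=None,
           thread_num=None, policy=None, enable_rasteration=None):
    # Builds and returns the T.prim_func computing C = A @ B^T (body elided).
    ...
```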

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check whether kernel parameters are set before processing tunable arguments, improving robustness of configuration handling.
- When the tunable parameters are already provided by the caller, the tuning sweep is skipped and compilation goes straight through JIT, avoiding wasted work (see the usage sketch after this list).
- Updated comments for clarity and maintainability.
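
In practice, callers can now bypass the search entirely by supplying the tunable parameters as keyword arguments (adapted from the Tips example added to the `autotune` docstring in this diff):

```python
if enable_autotune:
    # No tunable kwargs: the AutoTuner sweeps the configuration space.
    kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
    # Tunable kwargs provided: compilation goes straight through JIT, no sweep.
    kernel = flashattn(batch, heads, seq_len, dim, is_causal, groups=groups,
                       block_M=128, block_N=128, num_stages=2, threads=256)
```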

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call, re-enabling it when the script is run directly.
parent 78056597
......@@ -29,7 +29,7 @@ def ref_program(A, B):
return A @ B.T
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
......@@ -44,6 +44,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K, with_roller = args[:4]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
......@@ -86,40 +88,40 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[32, 64],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs
def matmul(M, N, K, with_roller):
@autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
......@@ -146,55 +148,6 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
......@@ -218,8 +171,7 @@ def matmul(M, N, K, with_roller):
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -257,8 +209,6 @@ def matmul(M, N, K, with_roller):
return main
return kernel()
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
......
......@@ -162,7 +162,7 @@ def ref_program(A, B):
return A @ B.T
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
......@@ -177,6 +177,9 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K = args[:3]
with_roller = args[6]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
......@@ -218,54 +221,38 @@ def get_configs(M, N, K, with_roller=False):
print(config)
else:
block_rows_warps = [1, 2, 4]
block_col_warps = [1, 2, 4]
warp_row_tiles = [16, 32, 64, 128]
warp_col_tiles = [16, 32, 64, 128]
chunk = [32, 64, 128, 256]
stage = [0, 2]
enable_rasteration = [True, False]
_configs = list(
itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles,
chunk, stage, enable_rasteration))
configs = [{
"block_row_warps": c[0],
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
warp_row_tiles=[16, 32, 64, 128],
warp_col_tiles=[16, 32, 64, 128],
chunk=[32, 64, 128, 256],
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return configs
def matmul(M,
@autotune(
configs=get_configs,
warmup=3,
rep=5,
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
with_roller=False):
"""Create an autotuned tensor core matrix multiplication kernel."""
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
)
@tl.jit(out_idx=[2],)
def kernel(
with_roller=False,
block_row_warps=None,
block_col_warps=None,
warp_row_tiles=None,
......@@ -273,7 +260,10 @@ def matmul(M,
chunk=None,
stage=None,
enable_rasteration=None,
):
):
"""Create an autotuned tensor core matrix multiplication kernel."""
def kernel():
return tl_matmul(
M,
N,
......
......@@ -30,7 +30,7 @@ def ref_program(A, B):
return A.float() @ B.T.float()
def get_configs(M, N, K, with_roller=False):
def get_configs(args, kwargs):
"""
Generate a list of configuration dictionaries that will be used for tuning.
......@@ -45,6 +45,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
M, N, K, with_roller = args[:4]
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
......@@ -87,40 +89,41 @@ def get_configs(M, N, K, with_roller=False):
for config in configs:
print(config)
else:
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul(M, N, K, with_roller):
@autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
......@@ -147,54 +150,6 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup).
"""
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "e4m3_float8"
......@@ -218,8 +173,7 @@ def matmul(M, N, K, with_roller):
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
......@@ -257,8 +211,6 @@ def matmul(M, N, K, with_roller):
return main
return kernel()
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
......
......@@ -337,10 +337,6 @@ def get_best_config(N, K):
@autotune(
configs=get_configs(),
keys=[
"BLOCK_N",
"reduce_threads",
],
warmup=3,
rep=20,
)
......
......@@ -223,6 +223,7 @@ def convolution(N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K
......@@ -291,14 +292,14 @@ def main(n: int = 128,
with_roller: bool = True):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
use_autotune = True
if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
print(result.config)
kernel = result.kernel
else:
config = get_heuristic_config()
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2])
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
......
......@@ -57,7 +57,18 @@ def get_configs(user_config=None):
return valid_configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
......@@ -65,9 +76,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
......@@ -146,8 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -172,8 +179,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -183,25 +190,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
......@@ -242,8 +230,16 @@ def main(batch: int = 1,
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
block_M=64, block_N=64, num_stages=2, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......@@ -255,10 +251,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -9,22 +9,33 @@ from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
iter_params = dict(
block_M=[128],
block_N=[128],
num_stages=[2],
threads=[256],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(
configs=get_configs(),
warmup=10,
rep=10,
)
@tilelang.jit(out_idx=[3])
def flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128,
):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
......@@ -32,9 +43,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
......@@ -113,8 +121,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -144,8 +151,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -155,25 +162,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
......@@ -216,8 +204,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......@@ -229,10 +225,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -6,35 +6,31 @@ import tilelang.language as T
import itertools
import argparse
from functools import partial
from tilelang import jit
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
......@@ -141,8 +137,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -152,21 +148,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
......@@ -199,8 +180,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
block_M=64, block_N=64, num_stages=1, threads=128)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
......@@ -213,10 +202,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -6,35 +6,31 @@ import tilelang.language as T
import itertools
import argparse
from functools import partial
from tilelang import jit
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
......@@ -146,8 +142,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -157,21 +153,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
......@@ -204,8 +185,16 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
......@@ -218,10 +207,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -9,30 +9,26 @@ from functools import partial
def get_configs():
block_M = [64]
block_N = [64]
num_stages = [1]
threads = [128]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -137,8 +132,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -148,21 +143,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
......@@ -193,8 +173,15 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=1, threads=128)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......
......@@ -9,30 +9,26 @@ from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -142,8 +137,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
......@@ -153,21 +148,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
......@@ -198,8 +178,15 @@ def main(
if (not tune):
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)(
block_M=128, block_N=128, num_stages=2, threads=256)
batch,
heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......@@ -211,10 +198,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -54,7 +54,10 @@ def get_pass_configs():
return {}
def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages,
threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim]
......@@ -64,8 +67,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
accum_dtype = "float"
kv_group_num = heads // groups
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def kernel_func(block_N, block_H, num_split, num_stages, threads):
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split)
......@@ -78,8 +79,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -108,12 +108,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype))
......@@ -148,8 +143,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -179,23 +173,18 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(
K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head, :], K_shared)
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], K_shared)
T.copy(
mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head], mask_local)
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head], mask_local)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(
(mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j],
-T.infinity(accum_dtype))
acc_s[i,
j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -211,8 +200,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
acc_o[i, j] *= scores_scale[i]
T.copy(
V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid +
(k + 1) * valid_block_N, cur_kv_head, :], V_shared)
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
......@@ -242,13 +231,10 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local:
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local:
T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id)
lse_local:
T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
})
T.clear(lse_logsum_local)
......@@ -300,25 +286,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
else:
return flashattn_gqa_decode_no_split
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(
out_idx=[6],
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
max_mismatched_ratio=0.05)
def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel()
else:
def kernel(block_N, block_H, num_split, num_stages, threads):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel
def ref_program(query, key, value, mask, glse, Output_partial):
# """
......@@ -485,7 +452,7 @@ def main(batch: int = 1,
if (not tune):
config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config)
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
......@@ -513,10 +480,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......
......@@ -219,17 +219,10 @@ def splitk_gemv_vectorized_tvm(
def get_best_config(N, K):
def get_configs():
BLOCK_N = [2, 4, 8, 32, 64, 128]
reduce_threads = [4, 8, 32]
_configs = list(itertools.product(
BLOCK_N,
reduce_threads,
))
configs = [{
"BLOCK_N": c[0],
"reduce_threads": c[1],
} for c in _configs]
return configs
iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
return [
dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
]
@autotune(
configs=get_configs(),
......
......@@ -61,42 +61,43 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def get_configs():
block_M = [64, 128, 256]
block_N = [32, 64]
block_K = [64, 128, 256]
block_Dstate = [128]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'block_Dstate': c[3],
'num_stages': c[4],
'threads': c[0] * 2
} for c in _configs]
return configs
iter_params = dict(
block_M=[64, 128, 256],
block_N=[32, 64],
block_K=[64, 128, 256],
block_Dstate=[128],
num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
def chunk_scan_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):
@T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype), C: T.Tensor(
(batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
(batch, nheads, nchunks, chunk_size), dtype),
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
(batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
(nheads), dtype), Output: T.Tensor(
(batch, seqlen, nheads, headdim), dtype)):
(nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)):
with T.Kernel(
nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
......@@ -162,8 +163,9 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
dA_cs_k_local[j] * p)
cb_local[i,
j] = cb_local[i,
j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K):
......@@ -188,32 +190,11 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
T.copy(acc_o, acc_o_shared)
T.copy(
acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size +
m_idx * block_M:chunk_idx * chunk_size + (m_idx + 1) * block_M, bz,
n_idx * block_N:(n_idx + 1) * block_N])
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
(m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
def kernel(block_M=None,
block_N=None,
block_K=None,
block_Dstate=None,
num_stages=None,
threads=None):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -231,8 +212,19 @@ if __name__ == "__main__":
if (not args.tune):
kernel = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128)
batch,
seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -243,11 +235,10 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
......@@ -46,31 +46,30 @@ def ref_program(B, x, dt, dA_cumsum):
def get_configs():
block_M = [64, 128]
block_N = [32, 64, 128]
block_K = [32, 64]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[0] * 2
} for c in _configs]
return configs
iter_params = dict(
block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False):
def chunk_state_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128):
dtype = "float16"
accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
def kernel_func(block_M, block_N, block_K, num_stages, threads):
@T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
......@@ -136,21 +135,6 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......@@ -168,8 +152,18 @@ if __name__ == "__main__":
if (not args.tune):
kernel = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)(
block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
batch,
seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=128,
block_K=64,
num_stages=4,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -180,8 +174,7 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
......
import torch
import torch.nn.functional as F
import tilelang
from tilelang import cached
import tilelang.language as T
import tilelang.testing
......@@ -9,6 +8,7 @@ import tilelang.testing
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[3, 4],)
def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
......@@ -78,6 +78,7 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
return flash_fwd
@tilelang.jit(out_idx=[2],)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
......@@ -113,6 +114,7 @@ def make_dq_layout(dQ):
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(out_idx=[1],)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
......@@ -134,11 +136,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post
@tilelang.jit(
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
@tilelang.jit(out_idx=[7, 8])
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
......@@ -161,10 +159,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
......@@ -182,7 +176,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
......@@ -237,8 +230,8 @@ class _attention(torch.autograd.Function):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4])
o, lse = mod(q, k, v)
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = kernel(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
......@@ -256,13 +249,13 @@ class _attention(torch.autograd.Function):
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 128 if D_HEAD <= 64 else 32
mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2])
mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1])
delta = mod_prep(o, do)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
delta = kernel_prep(o, do)
dq = torch.zeros_like(q, dtype=torch.float32)
dk, dv = mod(q, k, v, do, lse, delta, dq)
dq = mod_post(dq)
dk, dv = kernel(q, k, v, do, lse, delta, dq)
dq = kernel_post(dq)
return dq, dk, dv, None
......@@ -307,8 +300,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
def test_mha_bwd():
assert_mha_equal(8, 32, 128, 64, False)
assert_mha_equal(8, 32, 128, 64, True)
assert_mha_equal(8, 32, 256, 64, False)
assert_mha_equal(8, 32, 256, 64, True)
if __name__ == "__main__":
......
......@@ -263,6 +263,9 @@ class AutoTuner:
sig = inspect.signature(self.fn)
parameters = sig.parameters
if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters)
with self._lock:
......@@ -392,6 +395,31 @@ class AutoTuner:
raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs)
if len(config_args) == 0:
raise ValueError("No configurations to tune, please check your `@autotune` decorator")
# check if the tunable arguments has been set.
# get the back config argument
top_config, *rest = config_args
if self._kernel_parameters is not None:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
if any(key in top_config for key, _ in key_kwargs_tuple):
logger.warning(
f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
)
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
num_workers = max(1, int(get_available_cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = []
......@@ -502,7 +530,7 @@ class _AutoTunerImplementation:
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Any = None
configs: Union[Dict, Callable] = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
......@@ -514,7 +542,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False
def __init__(self,
configs: Any,
configs: Union[Dict, Callable],
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
......@@ -581,7 +609,6 @@ class _AutoTunerImplementation:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
configs = self.configs
@functools.wraps(fn)
def wrapper(*args, **kwargs):
......@@ -598,7 +625,7 @@ class _AutoTunerImplementation:
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=configs).set_profile_args(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
......@@ -634,7 +661,7 @@ class _AutoTunerImplementation:
def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only
configs: Any,
configs: Union[Dict, Callable],
# profile arguments
warmup: int = 25,
rep: int = 100,
......@@ -656,12 +683,29 @@ def autotune( # This is the new public interface
This decorator can be used without arguments (e.g., `@tilelang.jit`):
Applies JIT compilation with default settings.
Tips:
- If you want to skip the auto-tuning process, you can set override the tunable parameters in the function signature.
```python
if enable_autotune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
```
Parameters
----------
func_or_out_idx : Any, optional
If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
If using `@tilelang.jit` directly on a function, this argument is implicitly
the function to be decorated (and `out_idx` will be `None`).
configs : Dict or Callable
Configuration space to explore during auto-tuning.
warmup : int, optional
Number of warmup iterations before timing.
rep : int, optional
Number of repetitions for timing measurements.
timeout : int, optional
target : Union[str, Target], optional
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
......
......@@ -142,12 +142,12 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float
config: dict
ref_latency: float
libcode: str
func: Callable
kernel: Callable
latency: Optional[float] = None
config: Optional[dict] = None
ref_latency: Optional[float] = None
libcode: Optional[str] = None
func: Optional[Callable] = None
kernel: Optional[Callable] = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel):
"""
......