"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "071a9872cb76f07d09dc8a3c65046d35d921f4e6"
Commit d110d087 authored by Lei Wang, committed by LeiWang1999

[Refactor] refactor autotune examples (#617)

* [Refactor] Update tilelang kernel functions and remove unused imports

- Refactored the `flashattn_fwd`, `flashattn_bwd_preprocess`, and `flashattn_bwd_postprocess` functions to call their kernels directly instead of going through cached versions, improving clarity and performance.
- Added `@tilelang.jit` decorators with explicit output indices (`out_idx`) so each function compiles to a kernel when called (see the sketch after this list).
- Removed the now-unused import of `cached` from `tilelang`, streamlining the code.
- Commented out the main testing function call in `test_tilelang_kernel_mha_bwd.py`; it can be re-enabled later if needed.
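As a minimal, hedged sketch of the pattern in the first two bullets (the function name, shapes, and the `cached` call in the comment are illustrative placeholders, not the exact test code):

```python
import tilelang

# Old pattern (roughly): the kernel factory was wrapped by tilelang's caching
# helper, e.g. kernel = cached(flashattn_fwd, [3], batch, heads, seq_len, dim).
# New pattern: decorate the factory itself and call it directly.

@tilelang.jit(out_idx=[3])  # argument index 3 (the Output tensor) is the kernel output
def flashattn_fwd(batch, heads, seq_len, dim, block_M=64, block_N=64):
    # Build and return a T.prim_func here; the kernel body is elided in this sketch.
    ...

# Direct call compiles the kernel for the given shapes; no cached() wrapper needed:
# kernel = flashattn_fwd(8, 16, 1024, 64)
```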

* [Refactor] Simplify configuration generation in benchmark and example scripts

- Refactored the `get_configs` functions in multiple benchmark and example scripts to use a dictionary-based approach for parameter configuration (a minimal sketch follows this list), improving readability and maintainability.
- Updated the `flashattn` and `chunk_scan_fwd` functions to directly accept configuration parameters, enhancing flexibility in kernel tuning.
- Removed redundant code and streamlined the configuration generation process across various files, ensuring consistency in how configurations are defined and utilized.
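Concretely, the dictionary-based generation shared by the updated scripts boils down to the following self-contained sketch (candidate values mirror the matmul benchmark in this diff):

```python
import itertools


def get_configs():
    # Each key is a tunable parameter; each value lists its candidates.
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[32, 64],
        num_stages=[0, 1, 2, 3],
        thread_num=[128, 256],
    )
    # Cartesian product over the candidate lists, re-zipped with the keys so
    # every configuration comes back as a plain dict.
    return [
        dict(zip(iter_params, values))
        for values in itertools.product(*iter_params.values())
    ]


configs = get_configs()
print(len(configs), configs[0])
# 144 {'block_M': 64, 'block_N': 64, 'block_K': 32, 'num_stages': 0, 'thread_num': 128}
```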

* [Refactor] Update configuration handling in benchmark scripts

- Refactored the `get_configs` functions in benchmark scripts to accept the kernel's argument list (`args`, `kwargs`), so the tuning space can depend on the problem size (see the sketch after this list).
- Enhanced the `matmul` and `flashattn` functions to utilize the updated configuration approach, streamlining parameter handling for kernel tuning.
- Added `@autotune` decorators to relevant functions, ensuring consistent autotuning behavior across benchmarks.
- Cleaned up redundant code and improved overall readability in the affected files.
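A simplified sketch of the variable-argument pattern; the assumption (matching the updated scripts) is that when `configs` is a callable, the autotuner forwards the kernel's call arguments to it, so the search space can depend on the problem size. The size-dependent candidate list below is purely illustrative:

```python
import itertools


def get_configs(args, kwargs):
    # The tuner is assumed to pass the kernel's positional arguments through,
    # e.g. matmul(M, N, K, with_roller, ...) -> args[:4].
    M, N, K, with_roller = args[:4]
    block_candidates = [64, 128, 256] if max(M, N) >= 1024 else [32, 64]
    iter_params = dict(
        block_M=block_candidates,
        block_N=block_candidates,
        block_K=[32, 64],
    )
    return [
        dict(zip(iter_params, values))
        for values in itertools.product(*iter_params.values())
    ]


# In the benchmarks the function object itself is handed to the decorator:
#
#   @autotune(configs=get_configs, warmup=3, rep=20)
#   @jit(out_idx=[2])
#   def matmul(M, N, K, with_roller, block_M=None, block_N=None, ...):
#       ...
#
# so the config list is generated lazily from the actual (M, N, K) being tuned.
print(len(get_configs((4096, 4096, 4096, False), {})))  # 18 candidates
```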

* [Refactor] Clean up formatting and update subproject commit

- Updated the subproject commit reference in the TVM directory to indicate a dirty state.
- Removed unnecessary blank lines and improved formatting in the `benchmark_matmul` and `benchmark_matmul_fp8` scripts for better readability.
- Streamlined the function definitions in the `flashattn` example script to enhance clarity and maintainability.

* [Refactor] Update AutoTuner configuration handling

- Modified the AutoTuner class to check whether kernel parameters are already set before processing tunable arguments, improving robustness in configuration handling.
- Enhanced the logic so compilation is skipped when tunable parameters are already provided, avoiding redundant work (a simplified sketch follows this list).
- Updated comments for clarity and maintainability.
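A deliberately simplified, hypothetical sketch of the skip logic described above (the real `AutoTuner` keeps more state; every name here is illustrative):

```python
def tune_or_compile(kernel_params, call_kwargs, configs, compile_and_bench):
    """Pick a configuration for one kernel call.

    kernel_params     -- names of the tunable parameters declared by the kernel
    call_kwargs       -- keyword arguments the caller actually supplied
    configs           -- candidate configurations from get_configs(...)
    compile_and_bench -- callable: config dict -> (latency, compiled_kernel)
    """
    # If the caller already fixed every tunable parameter, there is nothing to
    # search: compile that single configuration and skip the sweep entirely.
    if kernel_params and all(name in call_kwargs for name in kernel_params):
        fixed = {name: call_kwargs[name] for name in kernel_params}
        return fixed, compile_and_bench(fixed)

    # Otherwise benchmark every candidate and keep the fastest one.
    results = [(compile_and_bench(cfg), cfg) for cfg in configs]
    (latency, kernel), best_cfg = min(results, key=lambda item: item[0][0])
    return best_cfg, (latency, kernel)
```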

* lint fix

* Update TVM subproject commit to indicate dirty state and modify MHA backward test cases

- Updated the subproject commit reference in the TVM directory to reflect a dirty state.
- Adjusted the `test_mha_bwd` function to use a new configuration for the MHA backward tests, changing the context size from 128 to 256.
- Uncommented the main testing function call so it runs again.
parent 78056597
...@@ -29,7 +29,7 @@ def ref_program(A, B): ...@@ -29,7 +29,7 @@ def ref_program(A, B):
return A @ B.T return A @ B.T
def get_configs(M, N, K, with_roller=False): def get_configs(args, kwargs):
""" """
Generate a list of configuration dictionaries that will be used for tuning. Generate a list of configuration dictionaries that will be used for tuning.
...@@ -44,6 +44,8 @@ def get_configs(M, N, K, with_roller=False): ...@@ -44,6 +44,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages, Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning. thread numbers, and other parameters to explore during autotuning.
""" """
M, N, K, with_roller = args[:4]
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
...@@ -86,40 +88,40 @@ def get_configs(M, N, K, with_roller=False): ...@@ -86,40 +88,40 @@ def get_configs(M, N, K, with_roller=False):
for config in configs: for config in configs:
print(config) print(config)
else: else:
iter_params = dict(
block_M = [64, 128, 256] block_M=[64, 128, 256],
block_N = [64, 128, 256] block_N=[64, 128, 256],
block_K = [32, 64] block_K=[32, 64],
num_stages = [0, 1, 2, 3] num_stages=[0, 1, 2, 3],
thread_num = [128, 256] thread_num=[128, 256],
policy = [T.GemmWarpPolicy.Square] policy=[T.GemmWarpPolicy.Square],
enable_rasterization = [True, False] enable_rasteration=[True, False],
_configs = list( )
itertools.product( return [{
block_M, k: v for k, v in zip(iter_params, values)
block_N, } for values in itertools.product(*iter_params.values())]
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs return configs
def matmul(M, N, K, with_roller): @autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
""" """
Create an autotuned matrix multiplication kernel for matrices of shape: Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K) - A: (M, K)
...@@ -146,55 +148,6 @@ def matmul(M, N, K, with_roller): ...@@ -146,55 +148,6 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup). The baseline latency of the reference program (for computing speedup).
""" """
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "float16" dtype = "float16"
...@@ -218,8 +171,7 @@ def matmul(M, N, K, with_roller): ...@@ -218,8 +171,7 @@ def matmul(M, N, K, with_roller):
""" """
# Bind x-dimension to block index in N, # Bind x-dimension to block index in N,
# y-dimension to block index in M. # y-dimension to block index in M.
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -257,8 +209,6 @@ def matmul(M, N, K, with_roller): ...@@ -257,8 +209,6 @@ def matmul(M, N, K, with_roller):
return main return main
return kernel()
if __name__ == "__main__": if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions # Parse command-line arguments for matrix dimensions
......
...@@ -162,7 +162,7 @@ def ref_program(A, B): ...@@ -162,7 +162,7 @@ def ref_program(A, B):
return A @ B.T return A @ B.T
def get_configs(M, N, K, with_roller=False): def get_configs(args, kwargs):
""" """
Generate a list of configuration dictionaries that will be used for tuning. Generate a list of configuration dictionaries that will be used for tuning.
...@@ -177,6 +177,9 @@ def get_configs(M, N, K, with_roller=False): ...@@ -177,6 +177,9 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages, Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning. thread numbers, and other parameters to explore during autotuning.
""" """
M, N, K = args[:3]
with_roller = args[6]
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
...@@ -218,54 +221,38 @@ def get_configs(M, N, K, with_roller=False): ...@@ -218,54 +221,38 @@ def get_configs(M, N, K, with_roller=False):
print(config) print(config)
else: else:
block_rows_warps = [1, 2, 4] iter_params = dict(
block_col_warps = [1, 2, 4] block_row_warps=[1, 2, 4],
warp_row_tiles = [16, 32, 64, 128] block_col_warps=[1, 2, 4],
warp_col_tiles = [16, 32, 64, 128] warp_row_tiles=[16, 32, 64, 128],
chunk = [32, 64, 128, 256] warp_col_tiles=[16, 32, 64, 128],
stage = [0, 2] chunk=[32, 64, 128, 256],
enable_rasteration = [True, False] stage=[0, 2],
_configs = list( enable_rasteration=[True, False],
itertools.product(block_rows_warps, block_col_warps, warp_row_tiles, warp_col_tiles, )
chunk, stage, enable_rasteration)) return [{
configs = [{ k: v for k, v in zip(iter_params, values)
"block_row_warps": c[0], } for values in itertools.product(*iter_params.values())]
"block_col_warps": c[1],
"warp_row_tiles": c[2],
"warp_col_tiles": c[3],
"chunk": c[4],
"stage": c[5],
"enable_rasteration": c[6],
} for c in _configs]
return configs return configs
def matmul(M, @autotune(
configs=get_configs,
warmup=3,
rep=5,
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
def matmul(
M,
N, N,
K, K,
in_dtype="float16", in_dtype="float16",
out_dtype="float16", out_dtype="float16",
accum_dtype="float16", accum_dtype="float16",
with_roller=False): with_roller=False,
"""Create an autotuned tensor core matrix multiplication kernel."""
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_row_warps",
"block_col_warps",
"warp_row_tiles",
"warp_col_tiles",
"chunk",
"enable_rasteration",
"stage",
],
warmup=3,
rep=5,
)
@tl.jit(out_idx=[2],)
def kernel(
block_row_warps=None, block_row_warps=None,
block_col_warps=None, block_col_warps=None,
warp_row_tiles=None, warp_row_tiles=None,
...@@ -273,7 +260,10 @@ def matmul(M, ...@@ -273,7 +260,10 @@ def matmul(M,
chunk=None, chunk=None,
stage=None, stage=None,
enable_rasteration=None, enable_rasteration=None,
): ):
"""Create an autotuned tensor core matrix multiplication kernel."""
def kernel():
return tl_matmul( return tl_matmul(
M, M,
N, N,
......
...@@ -30,7 +30,7 @@ def ref_program(A, B): ...@@ -30,7 +30,7 @@ def ref_program(A, B):
return A.float() @ B.T.float() return A.float() @ B.T.float()
def get_configs(M, N, K, with_roller=False): def get_configs(args, kwargs):
""" """
Generate a list of configuration dictionaries that will be used for tuning. Generate a list of configuration dictionaries that will be used for tuning.
...@@ -45,6 +45,8 @@ def get_configs(M, N, K, with_roller=False): ...@@ -45,6 +45,8 @@ def get_configs(M, N, K, with_roller=False):
Each configuration dict includes various block sizes, pipeline stages, Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning. thread numbers, and other parameters to explore during autotuning.
""" """
M, N, K, with_roller = args[:4]
if with_roller: if with_roller:
from tilelang.carver.template import MatmulTemplate from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA from tilelang.carver.arch import CUDA
...@@ -87,40 +89,41 @@ def get_configs(M, N, K, with_roller=False): ...@@ -87,40 +89,41 @@ def get_configs(M, N, K, with_roller=False):
for config in configs: for config in configs:
print(config) print(config)
else: else:
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
policy = [T.GemmWarpPolicy.Square]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasteration": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs return configs
def matmul(M, N, K, with_roller): @autotune(
configs=get_configs,
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def matmul(
M,
N,
K,
with_roller,
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
""" """
Create an autotuned matrix multiplication kernel for matrices of shape: Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K) - A: (M, K)
...@@ -147,54 +150,6 @@ def matmul(M, N, K, with_roller): ...@@ -147,54 +150,6 @@ def matmul(M, N, K, with_roller):
The baseline latency of the reference program (for computing speedup). The baseline latency of the reference program (for computing speedup).
""" """
# Decorate the kernel with autotune & jit, specifying:
# - Tuning config list
# - Profiling keys
# - Warmup and repetition counts for better measurement
# - A reference program for correctness verification
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
k_pack : int
K dimension packing factor to improve memory coalescing.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "e4m3_float8" dtype = "e4m3_float8"
...@@ -218,8 +173,7 @@ def matmul(M, N, K, with_roller): ...@@ -218,8 +173,7 @@ def matmul(M, N, K, with_roller):
""" """
# Bind x-dimension to block index in N, # Bind x-dimension to block index in N,
# y-dimension to block index in M. # y-dimension to block index in M.
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -257,8 +211,6 @@ def matmul(M, N, K, with_roller): ...@@ -257,8 +211,6 @@ def matmul(M, N, K, with_roller):
return main return main
return kernel()
if __name__ == "__main__": if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions # Parse command-line arguments for matrix dimensions
......
...@@ -337,10 +337,6 @@ def get_best_config(N, K): ...@@ -337,10 +337,6 @@ def get_best_config(N, K):
@autotune( @autotune(
configs=get_configs(), configs=get_configs(),
keys=[
"BLOCK_N",
"reduce_threads",
],
warmup=3, warmup=3,
rep=20, rep=20,
) )
......
...@@ -223,6 +223,7 @@ def convolution(N, ...@@ -223,6 +223,7 @@ def convolution(N,
block_K, block_K,
num_stages, num_stages,
thread_num, thread_num,
enable_rasteration,
dtype="float16", dtype="float16",
accum_dtype="float"): accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
...@@ -291,14 +292,14 @@ def main(n: int = 128, ...@@ -291,14 +292,14 @@ def main(n: int = 128,
with_roller: bool = True): with_roller: bool = True):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D) ref_prog = ref_program(S, P, D)
use_autotune = True
if use_autotune: if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller) result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
print(result.config) print(result.config)
kernel = result.kernel kernel = result.kernel
else: else:
config = get_heuristic_config() config = get_heuristic_config()
kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2]) kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench() tilelang_latency = profiler.do_bench()
......
...@@ -57,7 +57,18 @@ def get_configs(user_config=None): ...@@ -57,7 +57,18 @@ def get_configs(user_config=None):
return valid_configs return valid_configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def flashattn(batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
...@@ -65,9 +76,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -65,9 +76,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
...@@ -146,8 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -146,8 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -172,8 +179,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -172,8 +179,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -183,25 +190,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -183,25 +190,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
return main return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal, groups=1): def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D] # Q: [B, T, HQ, D]
...@@ -242,8 +230,16 @@ def main(batch: int = 1, ...@@ -242,8 +230,16 @@ def main(batch: int = 1,
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( batch,
block_M=64, block_N=64, num_stages=2, threads=128) heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -255,10 +251,10 @@ def main(batch: int = 1, ...@@ -255,10 +251,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
...@@ -9,22 +9,33 @@ from functools import partial ...@@ -9,22 +9,33 @@ from functools import partial
def get_configs(): def get_configs():
block_M = [128] iter_params = dict(
block_N = [128] block_M=[128],
num_stages = [2] block_N=[128],
threads = [256] num_stages=[2],
_configs = list(itertools.product(block_M, block_N, num_stages, threads)) threads=[256],
)
configs = [{ return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2], @autotune(
'threads': c[3] configs=get_configs(),
} for c in _configs] warmup=10,
return configs rep=10,
)
@tilelang.jit(out_idx=[3])
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): def flashattn(
batch,
heads,
seq_len,
dim,
is_causal,
groups=1,
block_M=64,
block_N=64,
num_stages=0,
threads=128,
):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim] q_shape = [batch, seq_len, heads, dim]
...@@ -32,9 +43,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -32,9 +43,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
...@@ -113,8 +121,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -113,8 +121,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
V: T.Tensor(kv_shape, dtype), V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype), Output: T.Tensor(q_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -144,8 +151,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -144,8 +151,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -155,25 +162,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): ...@@ -155,25 +162,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
return main return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal, groups=1): def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D] # Q: [B, T, HQ, D]
...@@ -216,8 +204,16 @@ def main( ...@@ -216,8 +204,16 @@ def main(
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune, groups=groups)( batch,
block_M=128, block_N=128, num_stages=2, threads=256) heads,
seq_len,
dim,
is_causal,
groups=groups,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -229,10 +225,10 @@ def main( ...@@ -229,10 +225,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
...@@ -6,35 +6,31 @@ import tilelang.language as T ...@@ -6,35 +6,31 @@ import tilelang.language as T
import itertools import itertools
import argparse import argparse
from functools import partial from functools import partial
from tilelang import jit
def get_configs(): def get_configs():
block_M = [128] iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
block_N = [128] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads)) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
configs = [{ def flashattn(batch,
'block_M': c[0], heads,
'block_N': c[1], seq_q,
'num_stages': c[2], seq_kv,
'threads': c[3] dim,
} for c in _configs] is_causal,
return configs block_M=64,
block_N=64,
num_stages=1,
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
...@@ -141,8 +137,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): ...@@ -141,8 +137,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -152,21 +148,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): ...@@ -152,21 +148,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
...@@ -199,8 +180,16 @@ def main( ...@@ -199,8 +180,16 @@ def main(
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( batch,
block_M=64, block_N=64, num_stages=1, threads=128) heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -213,10 +202,10 @@ def main( ...@@ -213,10 +202,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune) kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
...@@ -6,35 +6,31 @@ import tilelang.language as T ...@@ -6,35 +6,31 @@ import tilelang.language as T
import itertools import itertools
import argparse import argparse
from functools import partial from functools import partial
from tilelang import jit
def get_configs(): def get_configs():
block_M = [128] iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
block_N = [128] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads)) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
configs = [{ def flashattn(batch,
'block_M': c[0], heads,
'block_N': c[1], seq_q,
'num_stages': c[2], seq_kv,
'threads': c[3] dim,
} for c in _configs] is_causal,
return configs block_M=128,
block_N=128,
num_stages=2,
def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): threads=256):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim] q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim] kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(kv_shape, dtype), K: T.Tensor(kv_shape, dtype),
...@@ -146,8 +142,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): ...@@ -146,8 +142,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -157,21 +153,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False): ...@@ -157,21 +153,6 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
...@@ -204,8 +185,16 @@ def main( ...@@ -204,8 +185,16 @@ def main(
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune)( batch,
block_M=128, block_N=128, num_stages=2, threads=256) heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -218,10 +207,10 @@ def main( ...@@ -218,10 +207,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=tune) kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
...@@ -9,30 +9,26 @@ from functools import partial ...@@ -9,30 +9,26 @@ from functools import partial
def get_configs(): def get_configs():
block_M = [64] iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128])
block_N = [64] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
num_stages = [1]
threads = [128]
_configs = list(itertools.product(block_M, block_N, num_stages, threads)) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
configs = [{ def flashattn(batch,
'block_M': c[0], heads,
'block_N': c[1], seq_len,
'num_stages': c[2], dim,
'threads': c[3] is_causal,
} for c in _configs] block_M=64,
return configs block_N=64,
num_stages=1,
threads=128):
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
...@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -137,8 +132,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -137,8 +132,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -148,21 +143,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -148,21 +143,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
...@@ -193,8 +173,15 @@ def main( ...@@ -193,8 +173,15 @@ def main(
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)( batch,
block_M=128, block_N=128, num_stages=1, threads=128) heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=1,
threads=128)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
......
...@@ -9,30 +9,26 @@ from functools import partial ...@@ -9,30 +9,26 @@ from functools import partial
def get_configs(): def get_configs():
block_M = [128] iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
block_N = [128] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads)) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
configs = [{ def flashattn(batch,
'block_M': c[0], heads,
'block_N': c[1], seq_len,
'num_stages': c[2], dim,
'threads': c[3] is_causal,
} for c in _configs] block_M=128,
return configs block_N=128,
num_stages=2,
threads=256):
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
@tilelang.jit(out_idx=[3])
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro @T.macro
def MMA0( def MMA0(
K: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype),
...@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -111,8 +107,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
V: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype), Output: T.Tensor(shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -142,8 +137,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -142,8 +137,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
stage=[-1, 0, 0, 1, -1, 1], stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
scores_sum, logsum) logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
...@@ -153,21 +148,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): ...@@ -153,21 +148,6 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[3])
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal): def ref_program(Q, K, V, is_causal):
dim = Q.size(-1) dim = Q.size(-1)
...@@ -198,8 +178,15 @@ def main( ...@@ -198,8 +178,15 @@ def main(
if (not tune): if (not tune):
kernel = flashattn( kernel = flashattn(
batch, heads, seq_len, dim, is_causal, tune=tune)( batch,
block_M=128, block_N=128, num_stages=2, threads=256) heads,
seq_len,
dim,
is_causal,
block_M=128,
block_N=128,
num_stages=2,
threads=256)
ref_program_processed = partial(ref_program, is_causal=is_causal) ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
...@@ -211,10 +198,10 @@ def main( ...@@ -211,10 +198,10 @@ def main(
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, seq_len, dim, is_causal, tune=tune) kernel = flashattn(batch, heads, seq_len, dim, is_causal)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
...@@ -54,7 +54,10 @@ def get_pass_configs(): ...@@ -54,7 +54,10 @@ def get_pass_configs():
return {} return {}
def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages,
threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim] shape_k = [batch, seqlen_kv, groups, dim]
...@@ -64,8 +67,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -64,8 +67,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
accum_dtype = "float" accum_dtype = "float"
kv_group_num = heads // groups kv_group_num = heads // groups
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def kernel_func(block_N, block_H, num_split, num_stages, threads):
part_shape = [batch, heads, num_split, dim] part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num) valid_block_H = min(block_H, kv_group_num)
valid_block_N = min(block_N, seqlen_kv // num_split) valid_block_N = min(block_N, seqlen_kv // num_split)
...@@ -78,8 +79,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -78,8 +79,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel( with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -108,12 +108,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -108,12 +108,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
-T.infinity(accum_dtype)) -T.infinity(accum_dtype))
...@@ -148,8 +143,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -148,8 +143,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
): ):
with T.Kernel( with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -179,23 +173,18 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -179,23 +173,18 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K[bid, (seqlen_kv // num_split) * sid + K[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
(k + 1) * valid_block_N, cur_kv_head, :], K_shared) cur_kv_head, :], K_shared)
T.copy( T.copy(
mask[bid, (seqlen_kv // num_split) * sid + mask[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
(k + 1) * valid_block_N, cur_kv_head], mask_local) cur_kv_head], mask_local)
T.clear(acc_s) T.clear(acc_s)
T.gemm( T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else( acc_s[i,
(mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
-T.infinity(accum_dtype)) acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -211,8 +200,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -211,8 +200,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(
V[bid, (seqlen_kv // num_split) * sid + V[bid, (seqlen_kv // num_split) * sid +
k * valid_block_N:(seqlen_kv // num_split) * sid + k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
(k + 1) * valid_block_N, cur_kv_head, :], V_shared) cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
...@@ -242,13 +231,10 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -242,13 +231,10 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({ T.annotate_layout({
lse_logsum_local: lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
lse_max_local:
T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id) # lse_local: (local_id, thread_id)
lse_local: lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
}) })
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
...@@ -300,25 +286,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False): ...@@ -300,25 +286,6 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
else: else:
return flashattn_gqa_decode_no_split return flashattn_gqa_decode_no_split
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@jit(
out_idx=[6],
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
max_mismatched_ratio=0.05)
def kernel(block_N=None, block_H=None, num_split=None, num_stages=None, threads=None):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel()
else:
def kernel(block_N, block_H, num_split, num_stages, threads):
return kernel_func(block_N, block_H, num_split, num_stages, threads)
return kernel
def ref_program(query, key, value, mask, glse, Output_partial): def ref_program(query, key, value, mask, glse, Output_partial):
# """ # """
...@@ -485,7 +452,7 @@ def main(batch: int = 1, ...@@ -485,7 +452,7 @@ def main(batch: int = 1,
if (not tune): if (not tune):
config, sm_version = get_heuristic_config() config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune)(**config) kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16) q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
...@@ -513,10 +480,10 @@ def main(batch: int = 1, ...@@ -513,10 +480,10 @@ def main(batch: int = 1,
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = flashattn(batch, heads, groups, kv_seqlen, dim, tune=tune) kernel = flashattn(batch, heads, groups, kv_seqlen, dim)
best_latency = best_result.latency best_latency = kernel.latency
best_config = best_result.config best_config = kernel.config
ref_latency = best_result.ref_latency ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
......
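With the `tune=` switch removed, the same `flashattn` factory now serves both paths shown in `main` above. A minimal sketch of the two call styles, assuming the tunable names from this diff (`block_N`, `block_H`, `num_split`, `num_stages`, `threads`); the shapes and heuristic values below are illustrative, not the output of `get_heuristic_config()`:

```python
# Illustrative sizes and heuristic values only.
batch, heads, groups, kv_seqlen, dim = 1, 32, 8, 8192, 128
config = dict(block_N=128, block_H=64, num_split=4, num_stages=2, threads=128)

# Fixed configuration: splat the heuristic parameters straight into the factory.
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)

# Auto-tuned: omit the tunables; the decorator sweeps the config space and the
# returned object carries latency, config, and ref_latency.
tuned = flashattn(batch, heads, groups, kv_seqlen, dim)
print(tuned.latency, tuned.config, tuned.ref_latency)
```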
...@@ -219,17 +219,10 @@ def splitk_gemv_vectorized_tvm( ...@@ -219,17 +219,10 @@ def splitk_gemv_vectorized_tvm(
def get_best_config(N, K): def get_best_config(N, K):
def get_configs(): def get_configs():
BLOCK_N = [2, 4, 8, 32, 64, 128] iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
reduce_threads = [4, 8, 32] return [
_configs = list(itertools.product( dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
BLOCK_N, ]
reduce_threads,
))
configs = [{
"BLOCK_N": c[0],
"reduce_threads": c[1],
} for c in _configs]
return configs
@autotune( @autotune(
configs=get_configs(), configs=get_configs(),
......
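The `get_configs` rewrites throughout this commit share one pattern: a dict mapping each tunable parameter to its candidate values, expanded with `itertools.product` into a list of config dicts. A self-contained sketch (standard library only) of what that expansion produces:

```python
import itertools


def get_configs():
    # Each key is a tunable parameter; each value lists its candidates.
    iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
    # The Cartesian product yields one dict per candidate configuration.
    return [
        dict(zip(iter_params, values))
        for values in itertools.product(*iter_params.values())
    ]


configs = get_configs()
print(len(configs))  # 6 * 3 = 18 candidate configurations
print(configs[0])    # {'BLOCK_N': 2, 'reduce_threads': 4}
```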
...@@ -61,42 +61,43 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): ...@@ -61,42 +61,43 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
def get_configs(): def get_configs():
block_M = [64, 128, 256] iter_params = dict(
block_N = [32, 64] block_M=[64, 128, 256],
block_K = [64, 128, 256] block_N=[32, 64],
block_Dstate = [128] block_K=[64, 128, 256],
num_stages = [1, 2, 3, 4, 5] block_Dstate=[128],
_configs = list(itertools.product(block_M, block_N, block_K, block_Dstate, num_stages)) num_stages=[1, 2, 3, 4, 5])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'block_Dstate': c[3],
'num_stages': c[4],
'threads': c[0] * 2
} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7]) @tilelang.jit(out_idx=[7])
def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False): def chunk_scan_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size) nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504 p = 1.44269504
def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):
@T.prim_func @T.prim_func
def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
         x: T.Tensor((batch, seqlen, nheads, headdim), dtype),
         dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
         dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),
         C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),
         prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),
         D: T.Tensor((nheads), dtype),
         Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)):
with T.Kernel( with T.Kernel(
nheads, nheads,
T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
...@@ -162,8 +163,9 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, ...@@ -162,8 +163,9 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
dA_cs_k_shared) dA_cs_k_shared)
T.copy(dA_cs_k_shared, dA_cs_k_local) T.copy(dA_cs_k_shared, dA_cs_k_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p)
T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
T.copy(dt_shared, dt_local) T.copy(dt_shared, dt_local)
for i, j in T.Parallel(block_M, block_K): for i, j in T.Parallel(block_M, block_K):
...@@ -188,32 +190,11 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, ...@@ -188,32 +190,11 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
T.copy(acc_o, acc_o_shared) T.copy(acc_o, acc_o_shared)
T.copy( T.copy(
acc_o_shared, acc_o_shared,
Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[7])
def kernel(block_M=None,
block_N=None,
block_K=None,
block_Dstate=None,
num_stages=None,
threads=None):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, block_Dstate, num_stages, threads):
return kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads)
return kernel
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -231,8 +212,19 @@ if __name__ == "__main__": ...@@ -231,8 +212,19 @@ if __name__ == "__main__":
if (not args.tune): if (not args.tune):
kernel = chunk_scan_fwd( kernel = chunk_scan_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)( batch,
block_M=64, block_N=64, block_K=64, block_Dstate=128, num_stages=2, threads=128) seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=64,
block_K=64,
block_Dstate=128,
num_stages=2,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
...@@ -243,11 +235,10 @@ if __name__ == "__main__": ...@@ -243,11 +235,10 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = chunk_scan_fwd( kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune) best_latency = kernel.latency
best_latency = best_result.latency best_config = kernel.config
best_config = best_result.config ref_latency = kernel.ref_latency
ref_latency = best_result.ref_latency
print(f"Best latency: {best_latency}") print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
...@@ -46,31 +46,30 @@ def ref_program(B, x, dt, dA_cumsum): ...@@ -46,31 +46,30 @@ def ref_program(B, x, dt, dA_cumsum):
def get_configs(): def get_configs():
block_M = [64, 128] iter_params = dict(
block_N = [32, 64, 128] block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
block_K = [32, 64] return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
num_stages = [1, 2, 3, 4, 5]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[0] * 2
} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4]) @tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, tune=False): def chunk_state_fwd(batch,
seqlen,
chunk_size,
ngroups,
nheads,
headdim,
dstate,
block_M=64,
block_N=64,
block_K=64,
num_stages=2,
threads=128):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
nchunks = T.ceildiv(seqlen, chunk_size) nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504 p = 1.44269504
def kernel_func(block_M, block_N, block_K, num_stages, threads):
@T.prim_func @T.prim_func
def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor( def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
(batch, seqlen, nheads, headdim), dtype), dt: T.Tensor( (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
...@@ -136,21 +135,6 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, ...@@ -136,21 +135,6 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
return main return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -168,8 +152,18 @@ if __name__ == "__main__": ...@@ -168,8 +152,18 @@ if __name__ == "__main__":
if (not args.tune): if (not args.tune):
kernel = chunk_state_fwd( kernel = chunk_state_fwd(
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)( batch,
block_M=64, block_N=128, block_K=64, num_stages=4, threads=128) seq_len,
chunk_size,
groups,
heads,
dim,
dstate,
block_M=64,
block_N=128,
block_K=64,
num_stages=4,
threads=128)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
...@@ -180,8 +174,7 @@ if __name__ == "__main__": ...@@ -180,8 +174,7 @@ if __name__ == "__main__":
print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else: else:
best_result = chunk_state_fwd( best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
batch, seq_len, chunk_size, groups, heads, dim, dstate, tune=args.tune)
best_latency = best_result.latency best_latency = best_result.latency
best_config = best_result.config best_config = best_result.config
ref_latency = best_result.ref_latency ref_latency = best_result.ref_latency
......
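Both chunk kernels above (`chunk_scan_fwd`, `chunk_state_fwd`) now stack `@autotune` and `@tilelang.jit` directly on the factory and expose the tunables as keyword arguments with defaults. A hedged sketch of that layout and the two ways it gets called; the `autotune` import path and the problem sizes are assumptions:

```python
import itertools

import tilelang
from tilelang.autotuner import autotune  # assumed import path


def get_configs():
    iter_params = dict(
        block_M=[64, 128], block_N=[32, 64, 128], block_K=[32, 64], num_stages=[1, 2, 3, 4, 5])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
                    block_M=64, block_N=64, block_K=64, num_stages=2, threads=128):
    ...  # builds and returns the T.prim_func shown in the diff above


# Fixed configuration (illustrative sizes): supplying the tunables skips the sweep.
kernel = chunk_state_fwd(8, 4096, 256, 1, 8, 64, 128,
                         block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)

# Auto-tuned: omit the tunables and read the best result off the return value.
best = chunk_state_fwd(8, 4096, 256, 1, 8, 64, 128)
print(best.latency, best.config)
```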
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tilelang import tilelang
from tilelang import cached
import tilelang.language as T import tilelang.language as T
import tilelang.testing import tilelang.testing
...@@ -9,6 +8,7 @@ import tilelang.testing ...@@ -9,6 +8,7 @@ import tilelang.testing
tilelang.testing.set_random_seed(42) tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[3, 4],)
def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim] shape = [batch, seq_len, heads, dim]
...@@ -78,6 +78,7 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): ...@@ -78,6 +78,7 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
return flash_fwd return flash_fwd
@tilelang.jit(out_idx=[2],)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim): def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -113,6 +114,7 @@ def make_dq_layout(dQ): ...@@ -113,6 +114,7 @@ def make_dq_layout(dQ):
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(out_idx=[1],)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim): def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16" dtype = "float16"
accum_dtype = "float" accum_dtype = "float"
...@@ -134,11 +136,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim): ...@@ -134,11 +136,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post return flash_bwd_post
@tilelang.jit( @tilelang.jit(out_idx=[7, 8])
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
sm_scale = (1.0 / dim)**0.5 sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
...@@ -161,10 +159,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): ...@@ -161,10 +159,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz): with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype) dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype) q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype) V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
...@@ -182,7 +176,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): ...@@ -182,7 +176,6 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
T.annotate_layout({ T.annotate_layout({
dQ: make_dq_layout(dQ), dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}) })
...@@ -237,8 +230,8 @@ class _attention(torch.autograd.Function): ...@@ -237,8 +230,8 @@ class _attention(torch.autograd.Function):
BATCH, N_CTX, H, D_HEAD = q.shape BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64 block_M = 64
block_N = 64 if D_HEAD <= 128 else 32 block_N = 64 if D_HEAD <= 128 else 32
mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4]) kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = mod(q, k, v) o, lse = kernel(q, k, v)
ctx.save_for_backward(q, k, v, o, lse) ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal ctx.causal = causal
return o return o
...@@ -256,13 +249,13 @@ class _attention(torch.autograd.Function): ...@@ -256,13 +249,13 @@ class _attention(torch.autograd.Function):
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128 block_M = 128
block_N = 128 if D_HEAD <= 64 else 32 block_N = 128 if D_HEAD <= 64 else 32
mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2]) kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1]) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = mod_prep(o, do) kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) delta = kernel_prep(o, do)
dq = torch.zeros_like(q, dtype=torch.float32) dq = torch.zeros_like(q, dtype=torch.float32)
dk, dv = mod(q, k, v, do, lse, delta, dq) dk, dv = kernel(q, k, v, do, lse, delta, dq)
dq = mod_post(dq) dq = kernel_post(dq)
return dq, dk, dv, None return dq, dk, dv, None
...@@ -307,8 +300,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal): ...@@ -307,8 +300,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
def test_mha_bwd(): def test_mha_bwd():
assert_mha_equal(8, 32, 128, 64, False) assert_mha_equal(8, 32, 256, 64, False)
assert_mha_equal(8, 32, 128, 64, True) assert_mha_equal(8, 32, 256, 64, True)
if __name__ == "__main__": if __name__ == "__main__":
......
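With `@tilelang.jit(out_idx=...)` on the factories, the test no longer wraps them in `cached(...)`: calling a factory returns a compiled kernel whose listed outputs are allocated and returned automatically. A hedged usage sketch with illustrative shapes (requires CUDA and tilelang):

```python
import torch

BATCH, H, N_CTX, D_HEAD = 8, 32, 256, 64
q = torch.randn(BATCH, N_CTX, H, D_HEAD, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# out_idx=[3, 4]: the Output and lse tensors are created by the kernel and returned.
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, False, 64, 64)
o, lse = kernel(q, k, v)
```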
...@@ -263,6 +263,9 @@ class AutoTuner: ...@@ -263,6 +263,9 @@ class AutoTuner:
sig = inspect.signature(self.fn) sig = inspect.signature(self.fn)
parameters = sig.parameters parameters = sig.parameters
if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters) key = self.generate_cache_key(parameters)
with self._lock: with self._lock:
...@@ -392,6 +395,31 @@ class AutoTuner: ...@@ -392,6 +395,31 @@ class AutoTuner:
raise ValueError(f"Unused keys in config: {unused_keys}") raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs) config_args.append(new_kwargs)
if len(config_args) == 0:
raise ValueError("No configurations to tune, please check your `@autotune` decorator")
# Check whether the tunable arguments have already been set.
# Inspect the first config to learn which keys are tunable.
top_config, *rest = config_args
if self._kernel_parameters is not None:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]
# Skip the tuning sweep if any tunable argument was already supplied in the call's kwargs.
if any(key in top_config for key, _ in key_kwargs_tuple):
    logger.warning(
        f"Tunable parameters {tunable_arguments} already provided during auto-tuning. "
        "Skipping the tuning sweep and compiling directly via JIT.")
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
num_workers = max(1, int(get_available_cpu_count() * 0.9)) num_workers = max(1, int(get_available_cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = [] futures = []
...@@ -502,7 +530,7 @@ class _AutoTunerImplementation: ...@@ -502,7 +530,7 @@ class _AutoTunerImplementation:
warmup: int = 25 warmup: int = 25
rep: int = 100 rep: int = 100
timeout: int = 100 timeout: int = 100
configs: Any = None configs: Union[Dict, Callable] = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None ref_prog: Callable = None
supply_prog: Callable = None supply_prog: Callable = None
...@@ -514,7 +542,7 @@ class _AutoTunerImplementation: ...@@ -514,7 +542,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False cache_input_tensors: bool = False
def __init__(self, def __init__(self,
configs: Any, configs: Union[Dict, Callable],
warmup: int = 25, warmup: int = 25,
rep: int = 100, rep: int = 100,
timeout: int = 100, timeout: int = 100,
...@@ -581,7 +609,6 @@ class _AutoTunerImplementation: ...@@ -581,7 +609,6 @@ class _AutoTunerImplementation:
warmup = self.warmup warmup = self.warmup
rep = self.rep rep = self.rep
timeout = self.timeout timeout = self.timeout
configs = self.configs
@functools.wraps(fn) @functools.wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
...@@ -598,7 +625,7 @@ class _AutoTunerImplementation: ...@@ -598,7 +625,7 @@ class _AutoTunerImplementation:
compile_arguments = fn(__return_compile_arguments=True) compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner( autotuner = AutoTuner(
fn, configs=configs).set_profile_args( fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type, supply_type=self.supply_type,
ref_prog=self.ref_prog, ref_prog=self.ref_prog,
supply_prog=self.supply_prog, supply_prog=self.supply_prog,
...@@ -634,7 +661,7 @@ class _AutoTunerImplementation: ...@@ -634,7 +661,7 @@ class _AutoTunerImplementation:
def autotune( # This is the new public interface def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None, func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only *, # Indicates subsequent arguments are keyword-only
configs: Any, configs: Union[Dict, Callable],
# profile arguments # profile arguments
warmup: int = 25, warmup: int = 25,
rep: int = 100, rep: int = 100,
...@@ -656,12 +683,29 @@ def autotune( # This is the new public interface ...@@ -656,12 +683,29 @@ def autotune( # This is the new public interface
This decorator can be used without arguments (e.g., `@tilelang.jit`): This decorator can be used without arguments (e.g., `@tilelang.jit`):
Applies JIT compilation with default settings. Applies JIT compilation with default settings.
Tips:
- To skip the auto-tuning process, override the tunable parameters (declared with defaults in the function signature) by passing them explicitly at the call site:
```python
if enable_autotune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
```
Parameters Parameters
---------- ----------
func_or_out_idx : Any, optional func_or_out_idx : Any, optional
If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
If using `@tilelang.jit` directly on a function, this argument is implicitly If using `@tilelang.jit` directly on a function, this argument is implicitly
the function to be decorated (and `out_idx` will be `None`). the function to be decorated (and `out_idx` will be `None`).
configs : Dict or Callable
Configuration space to explore during auto-tuning; either a pre-built collection of candidate configurations or a callable that generates them from the kernel's call arguments.
warmup : int, optional
Number of warmup iterations before timing.
rep : int, optional
Number of repetitions for timing measurements.
timeout : int, optional
Maximum time in seconds allowed for evaluating a single configuration.
target : Union[str, Target], optional target : Union[str, Target], optional
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional target_host : Union[str, Target], optional
......
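Because `configs` may now be a callable, the search space can be derived from the kernel's own call arguments; judging from the `AutoTuner` change above, the callable is invoked with the captured positional and keyword argument tuples. A sketch under that assumption, with illustrative parameter names:

```python
import itertools


def get_configs(args, kwargs):
    # Assumption: the kernel's first positional arguments are M, N, K.
    M, N, K = args[:3]
    block_sizes = [64, 128] if max(M, N) >= 4096 else [32, 64]
    iter_params = dict(block_M=block_sizes, block_N=block_sizes, num_stages=[1, 2, 3])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


# Passed uncalled, so the tuner can evaluate it once the kernel arguments are known:
#   @autotune(configs=get_configs, warmup=10, rep=10)
```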
...@@ -142,12 +142,12 @@ class AutotuneResult: ...@@ -142,12 +142,12 @@ class AutotuneResult:
func: Optimized function. func: Optimized function.
kernel: Compiled kernel function. kernel: Compiled kernel function.
""" """
latency: float latency: Optional[float] = None
config: dict config: Optional[dict] = None
ref_latency: float ref_latency: Optional[float] = None
libcode: str libcode: Optional[str] = None
func: Callable func: Optional[Callable] = None
kernel: Callable kernel: Optional[Callable] = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel): def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel):
""" """
......
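Making every `AutotuneResult` field optional is what lets the direct-JIT path above return a result that carries only the compiled kernel and its source. A minimal standalone sketch of the relaxed dataclass and that partial construction:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class AutotuneResult:
    latency: Optional[float] = None
    config: Optional[dict] = None
    ref_latency: Optional[float] = None
    libcode: Optional[str] = None
    func: Optional[Callable] = None
    kernel: Optional[Callable] = None


# Mirrors the skip-tuning path: no latency/config/ref_latency were measured.
result = AutotuneResult(libcode="// kernel source ...", kernel=lambda *args: None)
assert result.latency is None and result.config is None
```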