Commit bfb5b0a3 authored by yyttt6, committed by LeiWang1999

[Bugfix] fix the unexpected keyword error of autotune (#438)

* yes

* [Bugfix] fix the unexpected keyword error of autotune

* format

* test
parent 6c737768
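Context for the change: on current tilelang the @autotune decorator factory no longer accepts a keys argument, so example scripts that still passed keys=[...] failed with Python's usual unexpected-keyword TypeError before any kernel was built. The diffs in this commit simply drop that argument wherever it was still used. Below is a minimal sketch of the failure and of the corrected call, assuming the decorator-factory form and import path used in these examples; the config values are illustrative only.

from tilelang.autotuner import autotune

configs = [{"block_M": 64, "block_N": 64, "num_stages": 1, "threads": 128}]

try:
    # Old call style removed by this commit; on current tilelang this raises
    # something like: TypeError: autotune() got an unexpected keyword argument 'keys'
    autotune(
        configs=configs,
        keys=["block_M", "block_N", "num_stages", "threads"],
        warmup=10,
        rep=10)
except TypeError as err:
    print(err)

# Fixed call style used throughout this commit: only configs/warmup/rep are passed.
tune_decorator = autotune(configs=configs, warmup=10, rep=10)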
import argparse
import itertools
import torch
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
# copied from https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

    @T.prim_func
    def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
                                                                                        out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            start_x = bx * block_N
            start_y = by * block_M
            for (local_y, local_x) in T.Parallel(block_M, block_N):
                y = start_y + local_y
                x = start_x + local_x
                C[y, x] = A[y, x] + B[y, x]

    return main


def ref_program(x, y):
    return x + y
def get_configs(M, N):
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    threads = [64, 128, 256]
    configs = list(itertools.product(block_M, block_N, threads))
    return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]


def get_best_config(M, N):

    def kernel(block_M=None, block_N=None, threads=None):
        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)

    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(M, N)).set_compile_args(
            out_idx=[-1],
            supply_type=tilelang.TensorSupplyType.Auto,
            ref_prog=ref_program,
            skip_check=False,
            target="cuda",
        )
    return autotuner.run(warmup=3, rep=20)
if __name__ == "__main__":
    program = elementwise_add(512, 1024, 128, 256, "float32", "float32", 128)
    kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
    profiler = kernel.get_profiler()
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All checks pass.")
    latency = profiler.do_bench(ref_program, warmup=500)
    print("Ref: {:.2f} ms".format(latency))
    latency = profiler.do_bench(warmup=500)
    print("Tile-lang: {:.2f} ms".format(latency))

    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=512)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    args = parser.parse_args()
    M, N = args.m, args.n
    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    if args.use_autotune:
        result = get_best_config(M, N)
        kernel = result.kernel
    else:
        # Default config
        config = {"block_M": 128, "block_N": 256, "threads": 128}
        kernel = tilelang.compile(
            elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1)

    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
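A hedged sketch of driving the autotune path above without the --use_autotune flag, for readers who want to reuse get_best_config directly. The .kernel attribute is exactly what the __main__ block relies on; treat any other attribute of the tuner result (config, latency, and so on) as an assumption that depends on the tilelang version.

import torch

# Hypothetical driver (not part of this commit) reusing the functions defined above.
M, N = 512, 1024
result = get_best_config(M, N)   # sweeps get_configs(M, N) via AutoTuner.run
kernel = result.kernel           # same attribute the example's __main__ uses

a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
out = kernel(a, b)
torch.testing.assert_close(out, a + b, rtol=1e-2, atol=1e-2)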
@@ -152,11 +152,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
......
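The hunk above is representative of the rest of the commit: only the keys argument disappears, while the config dictionaries and the kernel signature stay the same. Presumably the tuner can recover the tunable parameter names from each config dict (or from the wrapped function's keyword parameters), which is why a separate keys list is redundant. A small self-contained illustration of that idea, not of tilelang internals:

# Stand-in for the decorated kernels in these examples: every tunable parameter
# is a keyword argument, so a config dict both names and supplies the values.
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
    return (block_M, block_N, num_stages, threads)

config = {"block_M": 128, "block_N": 128, "num_stages": 2, "threads": 256}
assert kernel(**config) == (128, 128, 2, 256)
assert sorted(config) == sorted(["block_M", "block_N", "num_stages", "threads"])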
@@ -9,10 +9,10 @@ from functools import partial
 def get_configs():
-    block_M = [128]
-    block_N = [128]
-    num_stages = [2]
-    threads = [256]
+    block_M = [64]
+    block_N = [64]
+    num_stages = [1]
+    threads = [128]
     _configs = list(itertools.product(block_M, block_N, num_stages, threads))
     configs = [{
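With the reduced lists in this hunk, itertools.product yields exactly one combination. Assuming the config dicts are assembled from the product the same way as in the elementwise example (keyed by the tunable parameter names), the whole search space for this example is now:

import itertools

block_M, block_N, num_stages, threads = [64], [64], [1], [128]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [
    dict(zip(("block_M", "block_N", "num_stages", "threads"), c)) for c in _configs
]
assert configs == [{"block_M": 64, "block_N": 64, "num_stages": 1, "threads": 128}]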
@@ -149,11 +149,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -154,11 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
......
@@ -271,11 +271,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_N", "block_H", "num_split", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(
             out_idx=[6],
             supply_type=tilelang.TensorSupplyType.Auto,
......
@@ -44,8 +44,6 @@ def get_configs(M, N, K, with_roller=False, topk=20):
             config["thread_num"] = block_rows * block_cols * 32
             config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
             configs.append(config)
-        for config in configs:
-            print(config)
     else:
         block_M = [64, 128, 256]
         block_N = [64, 128, 256]
......
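Aside from the keys fix, the hunk above also drops two debug print lines from the roller-based config path. For context on the surviving lines, thread_num appears to be derived from the roller hint's block arrangement in units of 32-thread warps; with hypothetical hint values of a 4x2 warp layout, for instance:

# Hypothetical values standing in for what a roller hint would provide.
block_rows, block_cols = 4, 2
thread_num = block_rows * block_cols * 32   # 8 warps x 32 threads = 256 threads
assert thread_num == 256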
@@ -195,11 +195,7 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[7], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
         def kernel(block_M=None,
                    block_N=None,
......
@@ -137,11 +137,7 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     if tune:
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[4], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
         def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, block_K, num_stages, threads)
......