"examples/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "53be59dcc072c78730a83f848154357286c63ccd"
Commit bfb5b0a3 authored by yyttt6, committed by LeiWang1999

[Bugfix] fix the unexpected keyword error of autotune (#438)

* yes

* [Bugfix] fix the unexpected keyword error of autotune

* format

* test
parent 6c737768
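
Every hunk in this commit makes the same call-site change: the `keys=` argument is dropped from `@autotune(...)`. A hedged sketch of the change follows; the import path and the cause of the error are inferred from the diff and the commit title, not stated explicitly.

# Illustration of the call-site change only, not a standalone script.
from tilelang.autotuner import autotune  # assumed import path

configs = [{"block_M": 64, "block_N": 64, "num_stages": 1, "threads": 128}]

# Old call site, removed by this commit. Passing `keys=` is presumably what
# triggered the "unexpected keyword" error named in the commit title.
# autotune(configs=configs, keys=["block_M", "block_N", "num_stages", "threads"],
#          warmup=10, rep=10)

# New call site: only configs/warmup/rep are passed; the config dicts
# already name the tunable parameters.
decorator = autotune(configs=configs, warmup=10, rep=10)
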
+import argparse
+import itertools
+import torch
 import tilelang
 import tilelang.language as T
+from tilelang.autotuner import AutoTuner


-# copied from https://github.com/tile-ai/tilelang/blob/main/testing/python/kernel/test_tilelang_kernel_element_wise_add.py
-def elementwise_add(
-    M,
-    N,
-    block_M,
-    block_N,
-    in_dtype,
-    out_dtype,
-    threads,
-):
+def ref_program(x, y):
+    return x + y
+
+
+def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

     @T.prim_func
-    def main(
-            A: T.Tensor((M, N), in_dtype),
-            B: T.Tensor((M, N), in_dtype),
-            C: T.Tensor((M, N), out_dtype),
-    ):
+    def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
+                                                                                       out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             start_x = bx * block_N
             start_y = by * block_M
             for (local_y, local_x) in T.Parallel(block_M, block_N):
                 y = start_y + local_y
                 x = start_x + local_x
                 C[y, x] = A[y, x] + B[y, x]

     return main


-def ref_program(x, y):
-    return x + y
+def get_configs(M, N):
+    block_M = [64, 128, 256]
+    block_N = [64, 128, 256]
+    threads = [64, 128, 256]
+    configs = list(itertools.product(block_M, block_N, threads))
+    return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
+
+
+def get_best_config(M, N):
+
+    def kernel(block_M=None, block_N=None, threads=None):
+        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)
+
+    autotuner = AutoTuner.from_kernel(
+        kernel=kernel, configs=get_configs(M, N)).set_compile_args(
+            out_idx=[-1],
+            supply_type=tilelang.TensorSupplyType.Auto,
+            ref_prog=ref_program,
+            skip_check=False,
+            target="cuda",
+        )
+    return autotuner.run(warmup=3, rep=20)


 if __name__ == "__main__":
-    program = elementwise_add(512, 1024, 128, 256, "float32", "float32", 128)
-    kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
-    profiler = kernel.get_profiler()
-    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    print("All checks pass.")
-    latency = profiler.do_bench(ref_program, warmup=500)
-    print("Ref: {:.2f} ms".format(latency))
-    latency = profiler.do_bench(warmup=500)
-    print("Tile-lang: {:.2f} ms".format(latency))
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--m", type=int, default=512)
+    parser.add_argument("--n", type=int, default=1024)
+    parser.add_argument("--use_autotune", action="store_true", default=False)
+    args = parser.parse_args()
+    M, N = args.m, args.n
+
+    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
+    b = torch.randn(M, N, dtype=torch.float32, device="cuda")
+
+    if args.use_autotune:
+        result = get_best_config(M, N)
+        kernel = result.kernel
+    else:
+        # Default config
+        config = {"block_M": 128, "block_N": 256, "threads": 128}
+        kernel = tilelang.compile(
+            elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1)
+
+    out = kernel(a, b)
+    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
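
For reference, the tuning path of the rewritten example can also be driven from another script; a minimal sketch that reuses get_best_config and ref_program exactly as defined in the new version above (assumes a CUDA device):

# Minimal sketch: call the tuning path of the rewritten example directly.
# Reuses get_best_config/ref_program from the new file above; assumes CUDA.
import torch

M, N = 512, 1024
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")

result = get_best_config(M, N)  # AutoTuner sweep over get_configs(M, N)
kernel = result.kernel          # best compiled kernel found by the sweep
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)

Equivalently, the script itself now takes --m, --n, and --use_autotune flags.
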
@@ -152,11 +152,7 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -9,10 +9,10 @@ from functools import partial


 def get_configs():
-    block_M = [128]
-    block_N = [128]
-    num_stages = [2]
-    threads = [256]
+    block_M = [64]
+    block_N = [64]
+    num_stages = [1]
+    threads = [128]
     _configs = list(itertools.product(block_M, block_N, num_stages, threads))

     configs = [{
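
This hunk cuts off inside the config-list construction. For clarity, the reduced default search space works out to a single configuration; the sketch below assumes the dict keys mirror the tunable kernel arguments, as in the other examples in this commit:

# Sketch of the reduced default search space after this hunk. The dict
# construction is truncated in the diff; the key names are assumed to match
# the kernel parameters (block_M, block_N, num_stages, threads).
import itertools

block_M = [64]
block_N = [64]
num_stages = [1]
threads = [128]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
    "block_M": bm,
    "block_N": bn,
    "num_stages": ns,
    "threads": th,
} for bm, bn, ns, th in _configs]
print(configs)  # -> [{'block_M': 64, 'block_N': 64, 'num_stages': 1, 'threads': 128}]
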
@@ -149,11 +149,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -154,11 +154,7 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[3], supply_type=tilelang.TensorSupplyType.Integer, ref_prog=None)
         def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, num_stages, threads)
...
@@ -271,11 +271,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_N", "block_H", "num_split", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(
             out_idx=[6],
             supply_type=tilelang.TensorSupplyType.Auto,
...
@@ -44,8 +44,6 @@ def get_configs(M, N, K, with_roller=False, topk=20):
             config["thread_num"] = block_rows * block_cols * 32
             config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
             configs.append(config)
-        for config in configs:
-            print(config)
     else:
         block_M = [64, 128, 256]
         block_N = [64, 128, 256]
...
@@ -195,11 +195,7 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "block_Dstate", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[7], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
         def kernel(block_M=None,
                    block_N=None,
...
@@ -137,11 +137,7 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     if tune:

-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "num_stages", "threads"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @jit(out_idx=[4], supply_type=tilelang.TensorSupplyType.Normal, ref_prog=None)
         def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
             return kernel_func(block_M, block_N, block_K, num_stages, threads)
...