"git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "c3efeb5e20b91a0577f056f3839dfc30885a715e"
Unverified Commit f4f87f46 authored by senlyu163's avatar senlyu163 Committed by GitHub
Browse files

[Bugfix] Improve autotune for the elementwise_add function in examples (#1445)

* Remove JIT decorator from elementwise_add function in examples

* fix kernel compilation without autotune

* Refactor main function to accept parameters and update tests for autotune option

* Refactor autotune test function to modern style
parent 9c21586b
...@@ -3,13 +3,21 @@ import itertools ...@@ -3,13 +3,21 @@ import itertools
import torch import torch
import tilelang import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.autotuner import AutoTuner
def ref_program(x, y): def ref_program(x, y):
return x + y return x + y
def get_configs():
block_M = [64, 128, 256]
block_N = [64, 128, 256]
threads = [64, 128, 256]
configs = list(itertools.product(block_M, block_N, threads))
return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func @T.prim_func
...@@ -30,47 +38,12 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): ...@@ -30,47 +38,12 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
return elem_add return elem_add
def get_configs(M, N): def main(M=1024, N=1024, use_autotune=False):
block_M = [64, 128, 256]
block_N = [64, 128, 256]
threads = [64, 128, 256]
configs = list(itertools.product(block_M, block_N, threads))
return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
def get_best_config(M, N):
def kernel(block_M=None, block_N=None, threads=None):
return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)
autotuner = (
AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N))
.set_compile_args(
out_idx=[-1],
target="cuda",
)
.set_profile_args(
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
skip_check=False,
)
)
return autotuner.run(warmup=3, rep=20)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
M, N = args.m, args.n
a = torch.randn(M, N, dtype=torch.float32, device="cuda") a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda")
if args.use_autotune: if use_autotune:
result = get_best_config(M, N) kernel = elementwise_add(M, N, in_dtype="float32", out_dtype="float32")
kernel = result.kernel
else: else:
# Default config # Default config
config = {"block_M": 32, "block_N": 32, "threads": 128} config = {"block_M": 32, "block_N": 32, "threads": 128}
...@@ -81,4 +54,9 @@ def main(): ...@@ -81,4 +54,9 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
main(args.m, args.n, args.use_autotune)
...@@ -6,5 +6,9 @@ def test_example_elementwise_add(): ...@@ -6,5 +6,9 @@ def test_example_elementwise_add():
example_elementwise_add.main() example_elementwise_add.main()
def test_example_elementwise_add_autotune():
example_elementwise_add.main(use_autotune=True)
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment