Commit dd7eb488, authored by Cunxiao Ni and committed by LeiWang1999
Browse files

[CI] Add elementwise and gemv examples to CI. (#458)

* [CI] Add elementwise and gemv examples to CI.

* fix lint

* test

* fix gemv lint

* fix lint
parent 8d5e803e
...@@ -13,7 +13,7 @@ def ref_program(x, y): ...@@ -13,7 +13,7 @@ def ref_program(x, y):
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func @T.prim_func
def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
out_dtype)): out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N start_x = bx * block_N
...@@ -23,7 +23,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): ...@@ -23,7 +23,7 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
x = start_x + local_x x = start_x + local_x
C[y, x] = A[y, x] + B[y, x] C[y, x] = A[y, x] + B[y, x]
return main return elem_add
def get_configs(M, N): def get_configs(M, N):
...@@ -49,13 +49,12 @@ def get_best_config(M, N): ...@@ -49,13 +49,12 @@ def get_best_config(M, N):
) )
return autotuner.run(warmup=3, rep=20) return autotuner.run(warmup=3, rep=20)
def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512) parser.add_argument("--m", type=int, default=512)
parser.add_argument("--n", type=int, default=1024) parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False) parser.add_argument("--use_autotune", action="store_true", default=False)
args = parser.parse_args() args, _ = parser.parse_known_args()
M, N = args.m, args.n M, N = args.m, args.n
a = torch.randn(M, N, dtype=torch.float32, device="cuda") a = torch.randn(M, N, dtype=torch.float32, device="cuda")
...@@ -72,3 +71,7 @@ if __name__ == "__main__": ...@@ -72,3 +71,7 @@ if __name__ == "__main__":
out = kernel(a, b) out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not run the example.
if __name__ == "__main__":
    main()
import tilelang.testing
import example_elementwise_add
def test_example_elementwise_add():
    """Invoke the ``main()`` entry point of the ``example_elementwise_add`` module."""
    # No return value is checked here; the test relies on main() raising if
    # something goes wrong.
    example_elementwise_add.main()
# When this file is executed directly, hand control to tilelang.testing.main()
# (presumably the project's test-runner entry point -- confirm in tilelang.testing).
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -302,11 +302,11 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True): ...@@ -302,11 +302,11 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
print(f"TileLang Latency: {latency} ms\n") print(f"TileLang Latency: {latency} ms\n")
if __name__ == "__main__": def main():
parser = argparse.ArgumentParser(description="GEMV Example") parser = argparse.ArgumentParser(description="GEMV Example")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
args = parser.parse_args() args, _ = parser.parse_known_args()
N, K = args.n, args.k N, K = args.n, args.k
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K) check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K) check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K)
...@@ -316,9 +316,7 @@ if __name__ == "__main__": ...@@ -316,9 +316,7 @@ if __name__ == "__main__":
print("Test passed!") print("Test passed!")
best_result = get_best_config(N, K) best_result = get_best_config(N, K)
best_latency = best_result.latency
best_config = best_result.config best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = splitk_gemv_vectorized_tvm(N, K, *best_config) kernel = splitk_gemv_vectorized_tvm(N, K, *best_config)
kernel = tl.compile(kernel, out_idx=-1) kernel = tl.compile(kernel, out_idx=-1)
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -326,3 +324,7 @@ if __name__ == "__main__": ...@@ -326,3 +324,7 @@ if __name__ == "__main__":
print(f"Torch Latency: {latency} ms") print(f"Torch Latency: {latency} ms")
latency = profiler.do_bench(kernel, warmup=500) latency = profiler.do_bench(kernel, warmup=500)
print(f"TileLang Latency: {latency} ms\n") print(f"TileLang Latency: {latency} ms\n")
# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not run the example.
if __name__ == "__main__":
    main()
import tilelang.testing
import example_gemv
def test_example_gemv():
    """Invoke the ``main()`` entry point of the ``example_gemv`` module."""
    # No return value is checked here; the test relies on main() raising if
    # something goes wrong.
    example_gemv.main()
# When this file is executed directly, hand control to tilelang.testing.main()
# (presumably the project's test-runner entry point -- confirm in tilelang.testing).
if __name__ == "__main__":
    tilelang.testing.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment