Unverified Commit 5c62d00a authored by Zhengju Tang's avatar Zhengju Tang Committed by GitHub
Browse files

[Testing] Move TMA 1D and test for its functionality (#1167)

* [Testing] Move TMA 1D and test for its functionality

* [Lint]
parent 54d4bd62
import tilelang.testing
import example_elementwise_add
import example_elementwise_add_tma_1d
def test_example_elementwise_add():
example_elementwise_add.main()
def test_example_elementwise_add_tma_1d():
example_elementwise_add_tma_1d.main()
if __name__ == "__main__":
tilelang.testing.main()
import argparse
import torch
import tilelang
import tilelang.language as T
import torch
def ref_program(x, y):
......@@ -30,23 +29,29 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
return elem_add
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=128)
parser.add_argument("--n", type=int, default=128)
args, _ = parser.parse_known_args()
M, N = args.m, args.n
def run_elementwise_add(M, N):
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
# Default config
config = {"block_M": 128, "block_N": 128, "threads": 128}
block_M, block_N = 128, 128
config = {"block_M": block_M, "block_N": block_N, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
print("All passed!")
code = kernel.get_kernel_source()
if block_N == N:
assert "tma_load" in code and "CUtensorMap" not in code
else:
assert "tma_load" in code and "CUtensorMap" in code
def main():
run_elementwise_add(128, 128)
run_elementwise_add(256, 128)
run_elementwise_add(256, 256)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment