Commit ee4e708d authored by Leslin, committed by LeiWang1999

[CI] Add gemm and gemm_fp8 examples to CI (#516)

* [CI] Add gemm and gemm_fp8 examples to CI

* Fix lint via format.sh

* Resolve issues with profiler API usage and parse_args
parent 41bc15cb
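The change applies one pattern across every touched example, so it is worth stating once: each script's module-level driver code moves into a `main()` function behind an `if __name__ == "__main__":` guard, and the inner `@T.prim_func` is renamed from the clashing `main` to a descriptive name. That lets the new CI test files import each example without side effects and call `example.main()` directly. A minimal sketch of the pattern (illustrative names, not code from this commit):

```python
# before: executes at import time, so "import example" would launch the GPU run
# func = matmul(1024, 1024, 1024, 128, 128, 32)
# kernel = tilelang.compile(func, out_idx=-1)
# ...


# after: import-safe; CI calls example.main(), the CLI path still works
def main():
    func = matmul(1024, 1024, 1024, 128, 128, 32)  # build the TileLang program
    kernel = tilelang.compile(func, out_idx=-1)    # JIT-compile the kernel
    ...                                            # run and check, as in the diffs below


if __name__ == "__main__":
    main()
```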
example_gemm.py
@@ -5,7 +5,7 @@ import tilelang.language as T
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
-    def main(
+    def gemm(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), dtype),
@@ -23,32 +23,37 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return gemm


-func = matmul(1024, 1024, 1024, 128, 128, 32)
-print(func)
-kernel = tilelang.compile(func, out_idx=-1)
+def main():
+    func = matmul(1024, 1024, 1024, 128, 128, 32)
+    print(func)
+    kernel = tilelang.compile(func, out_idx=-1)

-import torch
+    import torch

-a = torch.randn(1024, 1024).cuda().half()
-b = torch.randn(1024, 1024).cuda().half()
-c = kernel(a, b)
-ref_c = a @ b
+    a = torch.randn(1024, 1024).cuda().half()
+    b = torch.randn(1024, 1024).cuda().half()
+    c = kernel(a, b)
+    ref_c = a @ b

-print("c:")
-print(c)
-print("ref_c:")
-print(ref_c)
+    print("c:")
+    print(c)
+    print("ref_c:")
+    print(ref_c)

-torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
-print("All check passed.")
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+    print("All check passed.")

-# Get CUDA Source
-print("CUDA Source:")
-print(kernel.get_kernel_source())
\ No newline at end of file
+    # Get CUDA Source
+    print("CUDA Source:")
+    print(kernel.get_kernel_source())
+
+
+if __name__ == "__main__":
+    main()
example_gemm_autotune.py
@@ -176,7 +176,7 @@ def matmul(M,
            accum_dtype="float"):

     @T.prim_func
-    def main(
+    def gemm_autotune(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((N, K), dtype),
             C: T.Tensor((M, N), dtype),
@@ -200,31 +200,16 @@ def matmul(M,
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return gemm_autotune


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
-    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
-    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
-    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
-    parser.add_argument(
-        "--use_autotune",
-        action="store_true",
-        default=False,
-        help="Whether to use autotune for matmul configs")
-    parser.add_argument(
-        "--with_roller",
-        action="store_true",
-        default=True,
-        help="Whether to enable BitBLAS roller for search space")
-    args = parser.parse_args()
-    M, N, K = args.m, args.n, args.k
-    a = torch.randn(M, K).cuda().half()
-    b = torch.randn(N, K).cuda().half()
-    use_autotune = args.use_autotune
-    with_roller = args.with_roller
+def main(m: int = 16384,
+         n: int = 16384,
+         k: int = 16384,
+         use_autotune: bool = False,
+         with_roller: bool = True):
+    M, N, K = m, n, k
+    use_autotune = True

     if use_autotune:
         result = get_best_config(M, N, K, with_roller)
         print(result.config)
@@ -242,3 +227,22 @@ if __name__ == "__main__":
     print(f"Ref latency: {ref_latency}")
     print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}")
     print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
+    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
+    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
+    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
+    parser.add_argument(
+        "--use_autotune",
+        action="store_true",
+        default=False,
+        help="Whether to use autotune for matmul configs")
+    parser.add_argument(
+        "--with_roller",
+        action="store_true",
+        default=True,
+        help="Whether to enable BitBLAS roller for search space")
+    args = parser.parse_args()
+    main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
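Two details of the new CLI wiring are easy to miss. First, as extracted here `main()` immediately pins `use_autotune = True`, so the `--use_autotune` flag currently has no effect on the run. Second, a flag declared with `action="store_true"` together with `default=True` can never be switched off from the command line, so `--with_roller` is effectively always on. A standalone demonstration of the argparse behavior (not code from this commit):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--with_roller", action="store_true", default=True)

print(parser.parse_args([]))                 # Namespace(with_roller=True)
print(parser.parse_args(["--with_roller"]))  # Namespace(with_roller=True)

# argparse.BooleanOptionalAction (Python 3.9+) would generate a paired
# --no-with_roller option, making the default actually overridable.
```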
example_gemm_intrinsics.py
 import torch
 import torch.backends
 from tilelang import tvm as tvm
 from tvm import DataType
 import tilelang
@@ -99,7 +97,7 @@ def tl_matmul(
     )

     @T.prim_func
-    def main(
+    def gemm_intrinsics(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype),
@@ -156,36 +154,33 @@ def tl_matmul(
                         j % micro_size_y,
                     ]

-    return main
+    return gemm_intrinsics


-M, N, K = 16384, 16384, 16384
-in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
-matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
-kernel = tilelang.compile(matmul, out_idx=[2])
-src_code = kernel.get_kernel_source()
-# src_code is the generated cuda source
-assert src_code is not None


-def ref_program(A, B):
-    return A @ B.T


-if in_dtype == "int8":
-    A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
-    B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
-else:
-    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))

-profiler = kernel.get_profiler()
-
-latency = profiler.do_bench(profiler.func, warmup=25)
-
-print(latency)
-
-# Ensure that the latency is not None
-assert latency is not None
-
-profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+def main():
+    M, N, K = 16384, 16384, 16384
+    in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
+    kernel = tilelang.compile(matmul, out_idx=[2])
+    src_code = kernel.get_kernel_source()
+    # src_code is the generated cuda source
+    assert src_code is not None

+    profiler = kernel.get_profiler()

+    latency = profiler.do_bench(profiler.func, warmup=25)

+    print(latency)
+    # Ensure that the latency is not None
+    assert latency is not None

+    def ref_program(A, B):
+        return A @ B.T

+    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    main()
example_gemm_schedule.py
 import tilelang
 from tilelang import Profiler
 import tilelang.language as T

 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

     @T.prim_func
-    def main(
+    def gemm_schedule(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((K, N), dtype),
             C: T.Tensor((M, N), dtype),
@@ -37,30 +36,37 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return gemm_schedule


-func = matmul(1024, 1024, 1024, 128, 128, 32)
-print(func)
-kernel = tilelang.compile(func, out_idx=-1)
+def main():
+    func = matmul(1024, 1024, 1024, 128, 128, 32)
+    print(func)
+    artifact = tilelang.lower(func)
+    profiler = Profiler(artifact.rt_mod, artifact.params, result_idx=[2])

-import torch
+    import torch

-a = torch.randn(1024, 1024).cuda().half()
-b = torch.randn(1024, 1024).cuda().half()
-c = kernel(a, b)
-ref_c = a @ b
+    a = torch.randn(1024, 1024).cuda().half()
+    b = torch.randn(1024, 1024).cuda().half()
+    c = profiler(a, b)
+    ref_c = a @ b

-print("c:")
-print(c)
-print("ref_c:")
-print(ref_c)
+    print(c)
+    print(ref_c)

-torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
-print("All check passed.")
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

-# Get CUDA Source
-print("CUDA Source:")
-print(kernel.get_kernel_source())
+    # Get CUDA Source
+    print(artifact.kernel_source)
+
+
+if __name__ == "__main__":
+    main()
test runner for the gemm examples (new file)
+import tilelang.testing
+
+import example_gemm_autotune
+import example_gemm_intrinsics
+import example_gemm_schedule
+import example_gemm
+
+
+def test_example_gemm_autotune():
+    example_gemm_autotune.main()
+
+
+def test_example_gemm_intrinsics():
+    example_gemm_intrinsics.main()
+
+
+def test_example_gemm_schedule():
+    example_gemm_schedule.main()
+
+
+def test_example_gemm():
+    example_gemm.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
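The runner above is plain pytest-style code: CI collects the `test_*` functions, and the `__main__` guard keeps the file runnable on its own. Assuming `tilelang.testing.main()` behaves like TVM's `tvm.testing.main()` helper, it simply runs pytest on the current file; a rough standalone equivalent:

```python
# Rough equivalent of what tilelang.testing.main() is assumed to do
# (mirroring tvm.testing.main): run pytest on this very file so that
# "python test_file.py" and "pytest test_file.py" behave the same.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
```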
example_tilelang_gemm_fp8.py
@@ -14,7 +14,7 @@ def calc_diff(x, y):
 def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):

     @T.prim_func
-    def main(
+    def gemm_fp8(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((N, K), dtype),
             C: T.Tensor((M, N), dtype),
@@ -32,7 +32,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
             T.copy(C_local, C[by * block_M, bx * block_N])

-    return main
+    return gemm_fp8


 def test_gemm_fp8(M, N, K, dtype):
@@ -56,6 +56,10 @@ def test_gemm_fp8(M, N, K, dtype):
     assert diff < 1e-3


-if __name__ == "__main__":
+def main():
     test_gemm_fp8(1024, 1024, 1024, 'e4m3_float8')
     test_gemm_fp8(1024, 1024, 1024, 'e5m2_float8')
\ No newline at end of file
+
+
+if __name__ == "__main__":
+    main()
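For context on the dtype strings exercised above: `e4m3_float8` has a 4-bit exponent and 3-bit mantissa (more precision, less range), while `e5m2_float8` has a 5-bit exponent and 2-bit mantissa (more range, less precision). Assuming a recent PyTorch build with FP8 support, the trade-off is easy to inspect:

```python
import torch

# e4m3 trades range for precision; e5m2 does the opposite.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).eps)  # 0.125
print(torch.finfo(torch.float8_e5m2).eps)    # 0.25
```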
example_tilelang_gemm_fp8_2xAcc.py
@@ -11,7 +11,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
     update_interval = 128 // block_K if block_K < 128 else 1

     @T.prim_func
-    def main(
+    def gemm_fp8_2xAcc(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((N, K), dtype),
             C: T.Tensor((M, N), accum_dtype),
@@ -43,7 +43,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
             T.copy(C_local_accum, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return gemm_fp8_2xAcc


 def calc_diff(x, y):
@@ -74,6 +74,10 @@ def test_gemm_fp8(M, N, K, dtype):
     assert diff < 1e-3


-if __name__ == "__main__":
+def main():
     test_gemm_fp8(1024, 1024, 8192, 'e4m3_float8')
     test_gemm_fp8(1024, 1024, 8192, 'e5m2_float8')
\ No newline at end of file
+
+
+if __name__ == "__main__":
+    main()
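The `2xAcc` kernel keeps two accumulators: tensor-core partial sums land in a fast low-precision tile, and every `update_interval = 128 // block_K` K-iterations they are promoted into `C_local_accum` (fp32) so rounding error cannot compound across the K=8192 reduction. A scalar sketch of that two-level accumulation idea (illustrative names only, not the kernel's code):

```python
# Two-level accumulation: periodically flush a fast running sum into a
# higher-precision accumulator so error cannot build up over a long reduction.
values = [0.1] * 8192   # stand-in for the per-K-step partial products
update_interval = 4     # cf. update_interval = 128 // block_K in the kernel

acc_hi = 0.0            # plays the role of C_local_accum (fp32)
acc_lo = 0.0            # plays the role of the tensor-core accumulator tile

for step, x in enumerate(values, start=1):
    acc_lo += x
    if step % update_interval == 0:
        acc_hi += acc_lo  # promote the partial sum
        acc_lo = 0.0      # reset the fast accumulator

acc_hi += acc_lo          # fold in any tail iterations
print(acc_hi)
```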
example_tilelang_gemm_fp8_intrinsic.py
@@ -104,7 +104,7 @@ def tl_matmul(
     )

     @T.prim_func
-    def main(
+    def gemm_fp8_intrinsic(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, in_dtype),
             C: T.Tensor((M, N), out_dtype),
@@ -172,7 +172,7 @@ def tl_matmul(
                         j % micro_size_y,
                     ]

-    return main
+    return gemm_fp8_intrinsic


 def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
@@ -201,7 +201,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)

-    profiler(A, B, C)
+    C = profiler(A, B)

     latency = profiler.do_bench(warmup=25)
@@ -215,12 +215,10 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version(8, 9)
-def test_assert_tl_matmul():
+def main():
     assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
     assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")


 if __name__ == "__main__":
-    tilelang.testing.main()
+    main()
\ No newline at end of file
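The `C = profiler(A, B)` hunk is the profiler-API fix named in the commit message: the kernel is compiled with `out_idx=[2]`, which marks the third `prim_func` parameter (`C`) as an output, so the profiler now takes only the inputs and allocates and returns `C` itself instead of being handed a preallocated buffer. A hedged sketch of the calling convention, reusing `tl_matmul` from the file above:

```python
import torch
import tilelang

# Assumed usage after this commit: out_idx=[2] marks the third prim_func
# parameter (C) as an output, so the kernel/profiler take inputs only and
# return C. `tl_matmul` is the builder defined in the file above.
func = tl_matmul(128, 128, 128, "e4m3_float8", "float32", "float32")
kernel = tilelang.compile(func, out_idx=[2])

A = torch.randn(128, 128, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(128, 128, device="cuda").to(torch.float8_e4m3fn)

profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
C = profiler(A, B)                      # old form was profiler(A, B, C)
latency = profiler.do_bench(warmup=25)  # benchmark with profiler-supplied inputs
```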
test runner for the gemm_fp8 examples (new file)
+import tilelang.testing
+
+import example_tilelang_gemm_fp8_2xAcc
+import example_tilelang_gemm_fp8_intrinsic
+import example_tilelang_gemm_fp8
+
+
+def test_example_tilelang_gemm_fp8_2xAcc():
+    example_tilelang_gemm_fp8_2xAcc.main()
+
+
+def test_example_tilelang_gemm_fp8_intrinsic():
+    example_tilelang_gemm_fp8_intrinsic.main()
+
+
+def test_example_tilelang_gemm_fp8():
+    example_tilelang_gemm_fp8.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()