Commit 316d3b97 authored by yyttt6, committed by LeiWang1999

add autotune to example_gemm.py (#252)

* add autotune to example_gemm.py

parent 2d0c4169
import tilelang
import argparse
import torch
import itertools
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
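# ref_program is the reference implementation handed to @jit via ref_prog: it
# accumulates A @ B.T into the preallocated output C (B is stored row-major
# with shape (N, K), matching the kernels below).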
def ref_program(A, B, C):
C += A @ B.T
def get_configs(M, N, K, with_roller=False):
if with_roller:
arch = CUDA("cuda")
topk = 10
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
            # block_rows, block_cols describe how the thread block is partitioned into warps
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
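# Illustrative note (not part of the commit): every candidate returned by
# get_configs is a plain dict of tunable parameters, e.g.
#   {"block_M": 128, "block_N": 128, "block_K": 32, "num_stages": 2,
#    "thread_num": 128, "enable_rasteration": True}
# The default grid above enumerates 3 * 3 * 2 * 4 * 2 * 2 = 288 candidates,
# while with_roller=True instead asks the carver template for the top-10
# hardware-aware hints.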
def get_best_config(M, N, K, with_roller=False):
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_M",
"block_N",
"block_K",
"num_stages",
"thread_num",
"enable_rasteration",
],
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
return kernel()
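# Usage sketch (return contract inferred from the __main__ block below):
#   best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
#   func = matmul(M, N, K, *best_config)   # splat the winning config back into matmul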
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])
return main
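# Standalone usage sketch (mirrors the __main__ block below; shapes are
# illustrative, and tl.compile / kernel(...) follow the API used there):
#   func = matmul(1024, 1024, 1024, 128, 128, 32, 3, 128, True)
#   kernel = tl.compile(func, out_idx=-1)   # out_idx=-1: C is allocated and returned
#   a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   c = kernel(a, b)                        # approximately a @ b.T in fp16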
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument(
"--use_autotune",
action="store_true",
default=True,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=True,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
a = torch.randn(M, K).cuda().half()
b = torch.randn(N, K).cuda().half()
c = torch.zeros(M, N).cuda().half()
configs = []
use_autotune = args.use_autotune
with_roller = args.with_roller
if use_autotune:
best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
func = matmul(M, N, K, *best_config)
else:
func = matmul(M, N, K, 128, 128, 32, 3, 128, True)
# print(func)
kernel = tl.compile(func, out_idx=-1)
out_c = kernel(a, b)
ref_c = a @ b.T + c
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
# print(kernel.get_kernel_source())
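    # Example invocation (hypothetical sizes; note that both flags above use
    # action="store_true" with default=True, so they are effectively always on):
    #   python example_gemm.py --m 4096 --n 4096 --k 4096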
......@@ -169,7 +169,6 @@ def matmul(M, N, K, with_roller):
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
profiler="auto",
target="auto",
)
def kernel(
......
......@@ -268,7 +268,7 @@ def run_ctypes_kernel_do_bench(M,
profiler = matmul_kernel.get_profiler()
ctypes_latency = profiler.do_bench(func=matmul_kernel)
print(f"Ctypes Latency: {ctypes_latency} ms")
assert ctypes_latency is not None
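    # With the explicit profiler="torch" argument gone, do_bench now picks the
    # backend itself via Profiler.determine_profiler (see the profiler.py hunk below).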
......
......@@ -270,7 +270,7 @@ def run_cython_kernel_do_bench(M,
cython_profiler = cython_matmul_kernel.get_profiler()
ctypes_profiler = ctypes_matmul_kernel.get_profiler()
cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel)
print(f"cython Latency: {cython_latency} ms")
# assert ctypes_latency is not None
......@@ -280,7 +280,7 @@ def run_cython_kernel_do_bench(M,
assert tvm_latency is not None
ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel)
print(f"ctypes Latency: {ctypes_latency} ms")
assert cython_latency is not None
......
......@@ -23,7 +23,6 @@ logging.basicConfig(
@dataclass(frozen=True)
class JITContext:
mod: tilelang.Profiler
out_idx: List[int]
supply_type: tilelang.TensorSupplyType
ref_prog: Callable
......@@ -31,7 +30,7 @@ class JITContext:
atol: float
max_mismatched_ratio: float
skip_check: bool
profiler: tilelang.Profiler
target: Literal['cuda', 'hip']
......@@ -58,8 +57,8 @@ class Autotuner:
self.jit_input_tensors = None
self.ref_input_tensors = None
def jit_compile(self, config_arg) -> JITContext:
jit_context = self.fn(*config_arg)
return jit_context
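    # Note: jit_compile now receives a single positional config tuple, matching the
    # config_args entries that run() builds below and hands to the compile workers
    # through partial(self.jit_compile, **kwds).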
def run(self, *args: Any, **kwds: Any) -> Any:
......@@ -72,7 +71,6 @@ class Autotuner:
def target_fn(jit_context):
# Unpack the context
profiler = jit_context.profiler
skip_check = jit_context.skip_check
ref_prog = jit_context.ref_prog
......@@ -80,28 +78,26 @@ class Autotuner:
atol = jit_context.atol
max_mismatched_ratio = jit_context.max_mismatched_ratio
self.jit_input_tensors = profiler._get_inputs(
with_output=profiler ==
"tvm") if self.jit_input_tensors is None else self.jit_input_tensors
if (not skip_check) and (ref_prog is not None):
profiler.assert_allclose(
ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)
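            # The candidate kernel is validated against ref_prog above before it is timed below.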
latency = profiler.do_bench(
profiler.func,
n_warmup=self.warmup,
n_repeat=self.rep,
input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = profiler._get_inputs(
with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
self.ref_latency_cache = profiler.do_bench(
ref_prog,
n_warmup=self.warmup,
n_repeat=self.rep,
profiler="torch",
input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
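        # target_fn yields a (kernel_latency, reference_latency) pair; the reference
        # program is benchmarked only once and cached in ref_latency_cache across configs.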
......@@ -119,10 +115,7 @@ class Autotuner:
new_args = tuple(new_args)
config_args.append(new_args)
worker = partial(self.jit_compile, **kwds)
# 90% utilization
num_workers = max(1, int(os.cpu_count() * 0.9))
......@@ -205,7 +198,6 @@ def jit(out_idx: List[int],
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
def wrapper(fn: Callable):
......@@ -213,13 +205,11 @@ def jit(out_idx: List[int],
@wraps(fn)
def decorator(*args, **kwargs) -> float:
# Enabling Efficient Fusion
kernel = tilelang.compile(
fn(*args, **kwargs), target=target, pass_configs={"tir.merge_static_smem": True})
profiler = kernel.get_profiler()
return JITContext(
mod=mod,
out_idx=out_idx,
supply_type=supply_type,
ref_prog=ref_prog,
......
"""The profiler and convert to torch utils"""
from typing import List, Optional, Callable, Any
from functools import partial
import torch
from contextlib import suppress
......@@ -91,7 +91,9 @@ class Profiler:
lib_outs = [lib_outs]
if isinstance(ref_outs, torch.Tensor):
ref_outs = [ref_outs]
elif ref_outs is None:
ref_outs = []
        assert len(lib_outs) == len(ref_outs), "len(lib_outs) does not equal len(ref_outs)!"
# torch.set_printoptions(edgeitems=torch.inf)
for lhs, rhs in zip(lib_outs, ref_outs):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
......@@ -133,9 +135,7 @@ class Profiler:
func = self.__call__
return func(*ins)
def determine_profiler(self, func: Optional[Callable] = None):
"""Determines which profiler backend to use based on function type.
Args:
......@@ -145,12 +145,10 @@ class Profiler:
Returns:
str: The determined profiler type ("torch" or "tvm")
"""
if profiler == "auto":
if isinstance(func, tvm.runtime.Module):
return "tvm"
else:
return "torch"
return profiler
if isinstance(func, tvm.runtime.Module):
return "tvm"
else:
return "torch"
def do_bench(
self,
......@@ -159,7 +157,6 @@ class Profiler:
rep: int = 100,
n_warmup: int = 1,
n_repeat: int = 1,
profiler: Literal["torch", "tvm", "auto"] = "auto",
input_tensors: List[torch.Tensor] = None,
) -> float:
"""Benchmarks the execution time of a given function.
......@@ -176,7 +173,7 @@ class Profiler:
Returns:
float: Average execution time in milliseconds
"""
profiler = self.determine_profiler(func)
if profiler == "torch":
if func is None:
assert self.adapter is not None, "benchmarking function should be provided"
......