Commit 316d3b97, authored by yyttt6, committed by LeiWang1999

add autotune to example_gemm.py (#252)

parent 2d0c4169
example_gemm.py

import argparse
import torch
import itertools
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization


def ref_program(A, B, C):
    C += A @ B.T
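

# Note: ref_program (above) is the PyTorch reference passed to @jit via ref_prog=
# for the autotuner's correctness check. B is stored as (N, K), so the reference
# computes A @ B.T, matching the transpose_B=True GEMMs below.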


def get_configs(M, N, K, with_roller=False):
    if with_roller:
        arch = CUDA("cuda")
        topk = 10
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"

        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")

        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            # block_rows, block_cols represents warp partitioning
            block_rows, block_cols = block_m // warp_m, block_n // warp_n
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = hint.pipeline_stage
            config["thread_num"] = block_rows * block_cols * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64, 128, 256]
        block_N = [64, 128, 256]
        block_K = [32, 64]
        num_stages = [0, 1, 2, 3]
        thread_num = [128, 256]
        enable_rasterization = [True, False]

        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            } for c in _configs
        ]
    return configs
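

# A single entry of the returned list looks like the following (values are
# illustrative, drawn from the enumerated search space above):
#   {"block_M": 128, "block_N": 128, "block_K": 32,
#    "num_stages": 2, "thread_num": 256, "enable_rasteration": True}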


def get_best_config(M, N, K, with_roller=False):

    @autotune(
        configs=get_configs(M, N, K, with_roller),
        keys=[
            "block_M",
            "block_N",
            "block_K",
            "num_stages",
            "thread_num",
            "enable_rasteration",
        ],
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                C_shared = T.alloc_shared((block_M, block_N), dtype)
                T.use_swizzle(panel_size=10, enable=enable_rasteration)
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

    return kernel()
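

# get_best_config returns whatever the autotuned kernel() call yields; in the
# __main__ block below its result is unpacked as (best_latency, best_config, ref_latency).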


def matmul(M,
           N,
           K,
           block_M,
           block_N,
           block_K,
           num_stages,
           thread_num,
           enable_rasteration,
           dtype="float16",
           accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
    parser.add_argument(
        "--use_autotune",
        action="store_true",
        default=True,
        help="Whether to use autotune for matmul configs")
    parser.add_argument(
        "--with_roller",
        action="store_true",
        default=True,
        help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    M, N, K = args.m, args.n, args.k

    a = torch.randn(M, K).cuda().half()
    b = torch.randn(N, K).cuda().half()
    c = torch.zeros(M, N).cuda().half()

    configs = []
    use_autotune = args.use_autotune
    with_roller = args.with_roller

    if use_autotune:
        best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
        func = matmul(M, N, K, *best_config)
    else:
        func = matmul(M, N, K, 128, 128, 32, 3, 128, True)
    # print(func)

    kernel = tl.compile(func, out_idx=-1)
    out_c = kernel(a, b)
    ref_c = a @ b.T + c
    torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
    # print(kernel.get_kernel_source())
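
For reuse outside the script, the same pieces can be driven programmatically. A minimal sketch, assuming example_gemm.py is importable as a module on the Python path:

# hypothetical driver; mirrors the __main__ block above
import torch
import tilelang as tl
from example_gemm import get_best_config, matmul

M = N = K = 1024
best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller=False)
kernel = tl.compile(matmul(M, N, K, *best_config), out_idx=-1)

a = torch.randn(M, K).cuda().half()
b = torch.randn(N, K).cuda().half()
torch.testing.assert_close(kernel(a, b), a @ b.T, rtol=1e-2, atol=1e-2)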
The remaining hunks in this commit update the autotuner, the profiler, and their call sites to match.

@@ -169,7 +169,6 @@ def matmul(M, N, K, with_roller):
         supply_type=tl.TensorSupplyType.Integer,
         ref_prog=ref_program,
         skip_check=True,
-        profiler="auto",
         target="auto",
     )
     def kernel(
@@ -268,7 +268,7 @@ def run_ctypes_kernel_do_bench(M,
     profiler = matmul_kernel.get_profiler()
-    ctypes_latency = profiler.do_bench(func=matmul_kernel, profiler="torch")
+    ctypes_latency = profiler.do_bench(func=matmul_kernel)
     print(f"Ctypes Latency: {ctypes_latency} ms")
     assert ctypes_latency is not None

@@ -270,7 +270,7 @@ def run_cython_kernel_do_bench(M,
     cython_profiler = cython_matmul_kernel.get_profiler()
     ctypes_profiler = ctypes_matmul_kernel.get_profiler()
-    cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel, profiler="torch")
+    cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel)
     print(f"cython Latency: {cython_latency} ms")
     # assert ctypes_latency is not None

@@ -280,7 +280,7 @@ def run_cython_kernel_do_bench(M,
     assert tvm_latency is not None
-    ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel, profiler="torch")
+    ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel)
     print(f"ctypes Latency: {ctypes_latency} ms")
     assert cython_latency is not None
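
With the profiler= argument gone, callers only pass the function to benchmark and the backend is picked automatically. A minimal sketch, assuming a kernel already built with tilelang.compile:

# hypothetical benchmark snippet; backend selection now happens inside do_bench
profiler = kernel.get_profiler()
latency = profiler.do_bench(func=kernel)   # "torch" backend for a Python-callable kernel
print(f"Latency: {latency} ms")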
@@ -23,7 +23,6 @@ logging.basicConfig(
 @dataclass(frozen=True)
 class JITContext:
-    mod: tilelang.Profiler
     out_idx: List[int]
     supply_type: tilelang.TensorSupplyType
     ref_prog: Callable

@@ -31,7 +30,7 @@ class JITContext:
     atol: float
     max_mismatched_ratio: float
     skip_check: bool
-    profiler: Literal['torch', 'tvm']
+    profiler: tilelang.Profiler
     target: Literal['cuda', 'hip']

@@ -58,8 +57,8 @@ class Autotuner:
         self.jit_input_tensors = None
         self.ref_input_tensors = None

-    def jit_compile(self, args: Any, **kwds: Any) -> JITContext:
-        jit_context = self.fn(*args, **kwds)
+    def jit_compile(self, config_arg) -> JITContext:
+        jit_context = self.fn(*config_arg)
         return jit_context

     def run(self, *args: Any, **kwds: Any) -> Any:

@@ -72,7 +71,6 @@ class Autotuner:
         def target_fn(jit_context):
             # Unpack the context
-            mod = jit_context.mod
             profiler = jit_context.profiler
             skip_check = jit_context.skip_check
             ref_prog = jit_context.ref_prog

@@ -80,28 +78,26 @@ class Autotuner:
             atol = jit_context.atol
             max_mismatched_ratio = jit_context.max_mismatched_ratio

-            self.jit_input_tensors = mod._get_inputs(
+            self.jit_input_tensors = profiler._get_inputs(
                 with_output=profiler ==
                 "tvm") if self.jit_input_tensors is None else self.jit_input_tensors

             if (not skip_check) and (ref_prog is not None):
-                mod.assert_allclose(
+                profiler.assert_allclose(
                     ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)

-            latency = mod.do_bench(
-                mod.func,
+            latency = profiler.do_bench(
+                profiler.func,
                 n_warmup=self.warmup,
                 n_repeat=self.rep,
-                profiler=profiler,
                 input_tensors=self.jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
-                self.ref_input_tensors = mod._get_inputs(
+                self.ref_input_tensors = profiler._get_inputs(
                     with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
-                self.ref_latency_cache = mod.do_bench(
+                self.ref_latency_cache = profiler.do_bench(
                     ref_prog,
                     n_warmup=self.warmup,
                     n_repeat=self.rep,
-                    profiler="torch",
                     input_tensors=self.ref_input_tensors)
             return latency, self.ref_latency_cache

@@ -119,10 +115,7 @@ class Autotuner:
             new_args = tuple(new_args)
             config_args.append(new_args)

-        worker = partial(
-            self.jit_compile,
-            **kwds,
-        )
+        worker = partial(self.jit_compile, **kwds)

         # 90% utilization
         num_workers = max(1, int(os.cpu_count() * 0.9))

@@ -205,7 +198,6 @@ def jit(out_idx: List[int],
         atol: float = 1e-2,
         max_mismatched_ratio: float = 0.01,
         skip_check: bool = False,
-        profiler: Literal['auto', 'torch', 'tvm'] = 'auto',
         target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:

     def wrapper(fn: Callable):

@@ -213,13 +205,11 @@ def jit(out_idx: List[int],
         @wraps(fn)
         def decorator(*args, **kwargs) -> float:
             # Enabling Efficient Fusion
-            with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
-                mod, params = tilelang.lower(fn(*args, **kwargs), target=target)
-
-            mod = tilelang.Profiler(mod, params, out_idx, supply_type)
+            kernel = tilelang.compile(
+                fn(*args, **kwargs), target=target, pass_configs={"tir.merge_static_smem": True})
+            profiler = kernel.get_profiler()
             return JITContext(
-                mod=mod,
                 out_idx=out_idx,
                 supply_type=supply_type,
                 ref_prog=ref_prog,
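
For callers of the decorator, the visible change is that @jit no longer accepts a profiler argument; backend choice moves into the Profiler itself (next hunks). A before/after sketch of a call site, condensed as comments:

# before: @jit(out_idx=[-1], supply_type=..., ref_prog=ref_program, profiler="auto", target="auto")
# after:  @jit(out_idx=[-1], supply_type=..., ref_prog=ref_program, target="auto")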
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from typing import List, Literal, Optional, Callable, Any from typing import List, Optional, Callable, Any
from functools import partial from functools import partial
import torch import torch
from contextlib import suppress from contextlib import suppress
...@@ -91,7 +91,9 @@ class Profiler: ...@@ -91,7 +91,9 @@ class Profiler:
lib_outs = [lib_outs] lib_outs = [lib_outs]
if isinstance(ref_outs, torch.Tensor): if isinstance(ref_outs, torch.Tensor):
ref_outs = [ref_outs] ref_outs = [ref_outs]
assert len(lib_outs) == len(ref_outs) elif ref_outs is None:
ref_outs = []
assert len(lib_outs) == len(ref_outs), "len(lib_outs) not equals to len(ref_outs) !"
# torch.set_printoptions(edgeitems=torch.inf) # torch.set_printoptions(edgeitems=torch.inf)
for lhs, rhs in zip(lib_outs, ref_outs): for lhs, rhs in zip(lib_outs, ref_outs):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol) # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
...@@ -133,9 +135,7 @@ class Profiler: ...@@ -133,9 +135,7 @@ class Profiler:
func = self.__call__ func = self.__call__
return func(*ins) return func(*ins)
def determine_profiler(self, def determine_profiler(self, func: Optional[Callable] = None):
func: Optional[Callable] = None,
profiler: Literal["torch", "tvm", "auto"] = "auto"):
"""Determines which profiler backend to use based on function type. """Determines which profiler backend to use based on function type.
Args: Args:
...@@ -145,12 +145,10 @@ class Profiler: ...@@ -145,12 +145,10 @@ class Profiler:
Returns: Returns:
str: The determined profiler type ("torch" or "tvm") str: The determined profiler type ("torch" or "tvm")
""" """
if profiler == "auto":
if isinstance(func, tvm.runtime.Module): if isinstance(func, tvm.runtime.Module):
return "tvm" return "tvm"
else: else:
return "torch" return "torch"
return profiler
def do_bench( def do_bench(
self, self,
...@@ -159,7 +157,6 @@ class Profiler: ...@@ -159,7 +157,6 @@ class Profiler:
rep: int = 100, rep: int = 100,
n_warmup: int = 1, n_warmup: int = 1,
n_repeat: int = 1, n_repeat: int = 1,
profiler: Literal["torch", "tvm", "auto"] = "auto",
input_tensors: List[torch.Tensor] = None, input_tensors: List[torch.Tensor] = None,
) -> float: ) -> float:
"""Benchmarks the execution time of a given function. """Benchmarks the execution time of a given function.
...@@ -176,7 +173,7 @@ class Profiler: ...@@ -176,7 +173,7 @@ class Profiler:
Returns: Returns:
float: Average execution time in milliseconds float: Average execution time in milliseconds
""" """
profiler = self.determine_profiler(func, profiler) profiler = self.determine_profiler(func)
if profiler == "torch": if profiler == "torch":
if func is None: if func is None:
assert self.adapter is not None, "benchmarking function should be provided" assert self.adapter is not None, "benchmarking function should be provided"
......
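
With the Literal-typed profiler parameters removed, the backend is inferred purely from the type of the function being benchmarked. A minimal sketch, assuming a compiled tilelang kernel and its profiler:

# hypothetical check of the automatic backend choice
profiler = kernel.get_profiler()
print(profiler.determine_profiler(kernel))   # "torch" for an ordinary Python callable
# passing a tvm.runtime.Module as func would select "tvm" instead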