Commit 541e1685 authored by yyttt6's avatar yyttt6 Committed by LeiWang1999

[Refactor] Enhance Autotune (#266)

* add autotune to example_gemm.py

* format init.py
parent 8ad53855
......@@ -3,7 +3,7 @@ import torch
import itertools
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import autotune, jit
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
......@@ -79,26 +79,6 @@ def get_configs(M, N, K, with_roller=False):
def get_best_config(M, N, K, with_roller=False):
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_M",
"block_N",
"block_K",
"num_stages",
"thread_num",
"enable_rasteration",
],
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
def kernel(
block_M=None,
block_N=None,
......@@ -138,7 +118,15 @@ def get_best_config(M, N, K, with_roller=False):
return main
return kernel()
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
return autotuner.run(warmup=3, rep=20)
def matmul(M,
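Taken together, the hunk above swaps the stacked @autotune/@jit decorators for the fluent AutoTuner API. A minimal sketch of the resulting call pattern, assuming the kernel factory and get_configs defined above (argument values mirror the diff):

autotuner = AutoTuner.from_kernel(
    kernel=kernel,  # the config-parameterized kernel factory
    configs=get_configs(M, N, K, with_roller),
).set_compile_args(
    out_idx=[-1],  # treat the last buffer (C) as the kernel output
    supply_type=tl.TensorSupplyType.Integer,
    ref_prog=ref_program,  # reference used for correctness checking and speedup
    skip_check=False,
    target="auto",
)
result = autotuner.run(warmup=3, rep=20)  # profiles every config, returns the best as an AutotuneResult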
......@@ -200,19 +188,16 @@ if __name__ == "__main__":
M, N, K = args.m, args.n, args.k
a = torch.randn(M, K).cuda().half()
b = torch.randn(N, K).cuda().half()
c = torch.zeros(M, N).cuda().half()
configs = []
use_autotune = args.use_autotune
with_roller = args.with_roller
if use_autotune:
best_latency, best_config, ref_latency = get_best_config(M, N, K, with_roller)
func = matmul(M, N, K, *best_config)
result = get_best_config(M, N, K, with_roller)
print(f"best latency {result.latency}")
kernel = result.kernel
else:
func = matmul(M, N, K, 128, 128, 32, 3, 128, True)
kernel = tl.compile(matmul(M, N, K, 128, 128, 32, 3, 128, True), out_idx=-1)
# print(func)
kernel = tl.compile(func, out_idx=-1)
out_c = kernel(a, b)
ref_c = a @ b.T + c
ref_c = ref_program(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
# print(kernel.get_kernel_source())
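For reference, the result object used in the __main__ path above exposes more than the best latency; a hedged sketch of inspecting it, with field names taken from the AutotuneResult dataclass introduced later in this commit:

result = get_best_config(M, N, K, with_roller)
print(result.latency)      # best measured latency across the tuned configs
print(result.config)       # the configuration that produced it
print(result.ref_latency)  # latency of ref_program, useful for computing speedup
kernel = result.kernel     # compiled kernel for the best config
out_c = kernel(a, b)       # callable directly on the input tensors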
......@@ -4,7 +4,7 @@ import logging
import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import autotune, jit
from tilelang.autotuner import AutoTuner
# Configure logger
logger = logging.getLogger(__name__)
......@@ -151,26 +151,6 @@ def matmul(M, N, K, with_roller):
# - The "tvm" profiler backend
# - HIP as the compilation target (modify as needed for your hardware)
@autotune(
configs=get_configs(M, N, K, with_roller),
keys=[
"block_M",
"block_N",
"block_K",
"num_stages",
"thread_num",
"enable_rasteration",
],
warmup=3,
rep=5,
)
@jit(
out_idx=[2],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=True,
target="auto",
)
def kernel(
block_M=None,
block_N=None,
......@@ -268,14 +248,24 @@ def matmul(M, N, K, with_roller):
return main
return kernel()
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
return autotuner.run(warmup=3, rep=20)
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=True)
get_configs(8192, 8192, 8192, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=True)
matmul(8192, 8192, 8192, with_roller=False)
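Because matmul now returns the tuner's result object instead of a (best_latency, best_config, ref_latency) tuple, a follow-up test could assert on its fields directly. A hedged sketch; the test name and the 1024 problem sizes are illustrative:

def test_autotune_matmul_result_fields():
    result = matmul(1024, 1024, 1024, with_roller=False)
    assert result.latency > 0          # best measured latency
    assert result.kernel is not None   # compiled kernel for the best config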
......
import itertools
import logging
import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import jit, autotune
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : torch.Tensor
    The matrix with shape (M, K).
B : torch.Tensor
    The matrix with shape (N, K).
Returns
-------
torch.Tensor
    The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
M, N, K : int
    The GEMM problem sizes.
with_roller : bool
    Whether to enable the BitBLAS roller to deduce the search space.
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64]
block_N = [64]
block_K = [32]
num_stages = [0, 1]
thread_num = [128]
enable_rasterization = [False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
Returns
-------
AutotuneResult
    The tuning result, including the best latency found among the tuned
    configurations, the configuration that produced it, the baseline latency
    of the reference program (for computing speedup), and the compiled kernel.
"""
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@jit(
out_idx=[-1],
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
target="auto",
)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return kernel()
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=True)
get_configs(8192, 8192, 8192, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=True)
matmul(8192, 8192, 8192, with_roller=False)
if __name__ == "__main__":
tilelang.testing.main()
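As a concrete illustration of the hint-to-config mapping in get_configs above, a roller hint with block (128, 128), warp (64, 64) and rstep (32,) would translate as follows (the hint values themselves are hypothetical):

block_m, block_n = 128, 128   # hint.block
warp_m, warp_n = 64, 64       # hint.warp
config = {
    "block_M": block_m,
    "block_N": block_n,
    "block_K": 32,  # hint.rstep[0]
    "num_stages": 0,
    # (128 * 128) // (64 * 64) = 4 warps -> 4 * 32 = 128 threads per block
    "thread_num": (block_m * block_n) // (warp_m * warp_n) * 32,
    "enable_rasteration": True,  # hint carries a non-trivial rasterization plan
}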
......@@ -3,14 +3,13 @@
import tilelang
from tilelang import tvm as tvm
import inspect
from functools import wraps
from typing import Any, Callable, List, Literal
from functools import wraps, partial
from typing import Callable, List, Literal, Any
from tqdm import tqdm
import logging
from dataclasses import dataclass
import concurrent.futures
import os
from functools import partial
logger = logging.getLogger(__name__)
......@@ -34,40 +33,65 @@ class JITContext:
target: Literal['cuda', 'hip']
class Autotuner:
@dataclass(frozen=True)
class AutotuneResult:
latency: float
config: dict
ref_latency: float
libcode: str
func: Callable
kernel: Callable
class AutoTuner:
def __init__(
self,
fn: Callable,
configs: Any,
keys: List[str],
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
):
def __init__(self, fn: Callable, configs):
self.fn = fn
self.configs = configs
self.keys = keys
self.warmup = warmup
self.rep = rep
self.timeout = timeout
# Precompute cached variables
self.ref_latency_cache = None
self.jit_input_tensors = None
self.ref_input_tensors = None
def jit_compile(self, config_arg) -> JITContext:
jit_context = self.fn(*config_arg)
return jit_context
@classmethod
def from_kernel(cls, kernel: Callable, configs):
return cls(kernel, configs)
def set_compile_args(self,
out_idx: List[int],
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
ref_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
target: Literal['auto', 'cuda', 'hip'] = 'auto'):
def _compile(*config_arg):
kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
profiler = kernel.get_profiler()
jit_context = JITContext(
out_idx=out_idx,
supply_type=supply_type,
ref_prog=ref_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
profiler=profiler,
target=target)
return jit_context
self.jit_compile = _compile
return self
def run(self, *args: Any, **kwds: Any) -> Any:
def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
sig = inspect.signature(self.fn)
bound_args = sig.bind(*args, **kwds)
keys = list(sig.parameters.keys())
bound_args = sig.bind()
bound_args.apply_defaults()
best_latency = 1e8
best_config = None
best_jit_context = None
def target_fn(jit_context):
# Unpack the context
......@@ -87,49 +111,38 @@ class Autotuner:
ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(
profiler.func,
n_warmup=self.warmup,
n_repeat=self.rep,
input_tensors=self.jit_input_tensors)
profiler.func, n_warmup=warmup, n_repeat=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = profiler._get_inputs(
with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
self.ref_latency_cache = profiler.do_bench(
ref_prog,
n_warmup=self.warmup,
n_repeat=self.rep,
input_tensors=self.ref_input_tensors)
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
# Parallel compilation
config_args = []
for config in self.configs:
new_args = []
for name, value in bound_args.arguments.items():
if name not in self.keys:
if name not in keys:
new_args.append(value)
else:
new_args.append(config[name])
new_args = tuple(new_args)
config_args.append(new_args)
worker = partial(self.jit_compile, **kwds)
# 90% utilization
num_workers = max(1, int(os.cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
# Submit all compilation jobs
futures = []
future_to_index = {} # Track which future corresponds to which config
future_to_index = {}
for i, config_arg in enumerate(config_args):
future = pool.submit(worker, config_arg)
future = pool.submit(
self.jit_compile,
*config_arg,
)
futures.append(future)
future_to_index[future] = i
# Process results with error handling
results_with_configs = []
for future in tqdm(
concurrent.futures.as_completed(futures),
......@@ -164,28 +177,34 @@ class Autotuner:
if latency < best_latency:
best_latency = latency
best_config = config
best_jit_context = jit_context
progress_bar.set_postfix({"best_latency": best_latency})
tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
pool.shutdown()
return best_latency, best_config, ref_latency
return AutotuneResult(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
libcode=best_jit_context.profiler.func.lib_code,
func=self.fn(*best_config),
kernel=best_jit_context.profiler.func)
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.run(*args, **kwds)
def __call__(self) -> Any:
return self.run()
def autotune(configs: Any,
keys: List[str],
warmup: int = 25,
rep: int = 100,
timeout: int = 100) -> Callable:
def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
"""
Decorator for autotuning a tilelang program.
"""
def decorator(fn: Callable) -> Autotuner:
return Autotuner(fn, configs=configs, keys=keys, warmup=warmup, rep=rep, timeout=timeout)
def decorator(fn: Callable) -> AutoTuner:
autotuner = AutoTuner(fn, configs=configs)
autotuner.jit_compile = fn
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
return autotuner
return decorator
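The reworked autotune decorator keeps the pre-existing decorator entry point usable by installing the wrapped function as jit_compile and pre-binding warmup/rep/timeout into run. A hedged sketch of that path, mirroring the test file earlier in this commit (argument values are illustrative):

@autotune(configs=get_configs(M, N, K, with_roller), warmup=3, rep=20)
@jit(out_idx=[-1], supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, target="auto")
def kernel(block_M=None, block_N=None, block_K=None,
           num_stages=None, thread_num=None, enable_rasteration=None):
    ...  # build and return the T.prim_func for this configuration

result = kernel()  # __call__ forwards to run(), which returns an AutotuneResult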
......