"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5dbe4f5de6398159f8c2bedd371bc116683edbd3"
Commit eec47592 authored by Lei Wang, committed by LeiWang1999

[AutoTune] Support `with set_autotune_inputs` to set auto tuning input tensors (#632)

* [Refactor] Simplify and modularize autotuner implementation

- Removed unused imports and extensive code sections from the autotuner module to enhance readability and maintainability.
- Modularized the code by introducing new imports for autotuning and capturing functionalities, streamlining the overall structure.
- Improved logging setup and removed redundant timeout handling functions, focusing on core autotuning logic.
- Updated the AutoTuner class to better utilize the new modular structure, ensuring efficient performance during auto-tuning processes.

* [Refactor] Clean up and enhance capture and tuner modules

- Improved code readability by removing unnecessary blank lines and organizing imports in `capture.py` and `tuner.py`.
- Enhanced logging in the `AutoTuner` class to provide clearer warnings regarding the usage of `supply_prog` in the context of auto-tuning.
- Streamlined the `CaptureStack` class for better thread-local context management.

* lint fix

* [Refactor] Simplify configuration and autotuning logic in blocksparse GEMM example

- Updated `get_configs` function to reduce the number of configurations, enhancing performance and clarity.
- Removed the `get_best_config` function, integrating its logic directly into the `blocksparse_matmul` function with the `@autotune` decorator for streamlined autotuning.
- Adjusted the main function to directly utilize the autotuned kernel, simplifying the overall structure and improving readability.
- Deleted obsolete test file for autotuning decorator, cleaning up the codebase.

* [Refactor] Improve code formatting and readability in autotune test file

- Reformatted the `matmul` function and `get_configs` function for better readability by adjusting line breaks and indentation.
- Fixed a typo in the `enable_rasteration` parameter name to ensure consistency.
- Cleaned up unnecessary blank lines to enhance overall code clarity.

* Update example_blocksparse_gemm.py

* Update capture.py
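
A minimal usage sketch of the new context manager, adapted from the test added in this PR (the `matmul` kernel and the 8192-sized shapes come from that test; everything else is illustrative):

import torch
from tilelang.autotuner import set_autotune_inputs

M = N = K = 8192
a = torch.randn(M, K, dtype=torch.float16).cuda()
b = torch.randn(N, K, dtype=torch.float16).cuda()

# Tuning inputs are taken from the context instead of a supply_prog;
# `matmul` is the autotunable @tilelang.jit kernel defined in the new test.
with set_autotune_inputs([a, b]):
    kernel = matmul(M, N, K)

c = kernel(a, b)  # run the tuned kernel on the same tensors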
parent b6fe9582
@@ -2,7 +2,6 @@ import argparse
import itertools
import tilelang
import tilelang.language as T
-from tilelang.autotuner import AutoTuner
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
import torch
@@ -34,7 +33,7 @@ print(f"Target Block Sparsity: {sparsity}")
print(f"Using Autotuner: {use_autotune}\n")
-def get_configs(M, N, K):
+def get_configs():
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [32, 64]
@@ -90,55 +89,7 @@ def supply_program(params: List[KernelParam]):
    return input_tensors
-def get_best_config(M, N, K):
-    # Define the kernel function to be tuned.
-    # Parameters like block_M, block_N, etc., are tuned by the AutoTuner.
-    def kernel(block_M=None,
-               block_N=None,
-               block_K=None,
-               num_stages=None,
-               thread_num=None,
-               enable_rasteration=None):
-        return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
-                                  enable_rasteration)
-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(M, N, K)
-    ).set_compile_args(
-        out_idx=[-1],  # Index of the output tensor
-        target="auto",  # Automatically detect target
-    ).set_profile_args(
-        # supply_type should not set here because we provide a custom supply
-        # function `supply_prog` and `supply_type` will be ignored.
-        # supply_prog: Provide the custom function to generate input tensors
-        # (A, B, BlockMask) for the kernel, allowing controlling sparsity via
-        # BlockMask generation.
-        supply_prog=supply_program,
-        # ref_prog: Using dense matmul (A @ B) as a placeholder reference.
-        # The 'correct' block-sparse reference (`ref_program` above) requires
-        # block_M, block_N, block_K parameters. However, these parameters are
-        # part of the configuration being *tuned* by the AutoTuner and cannot
-        # be fixed inputs to a static `ref_prog` function signature.
-        # This dense matmul serves only as a performance baseline.
-        ref_prog=lambda A, B, BlockMask: A @ B,
-        # skip_check: Set to True because the provided `ref_prog` does not
-        # compute the correct result for the block-sparse kernel.
-        skip_check=True,
-        # cache_input_tensors: Set to False because the shape of the BlockMask tensor
-        # (dependent on block_M, block_N, block_K being tuned) changes between
-        # different configurations. Reusing cached tensors from a previous
-        # configuration would lead to shape mismatches.
-        cache_input_tensors=False,
-    )
-    # Run the tuning process
-    return autotuner.run(warmup=3, rep=20)
+@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(M,
                       N,
@@ -192,22 +143,16 @@ def main():
        # Run the autotuner to find the best kernel configuration and performance
        # get_best_config is expected to return an object containing the compiled kernel,
        # the best configuration found, latency, and reference latency.
-        result = get_best_config(M, N, K)
-        # Extract results from the autotuner run
-        kernel = result.kernel
-        best_config = result.config
-        block_M = best_config[0]
-        block_N = best_config[1]
-        block_K = best_config[2]
-        best_latency = result.latency
-        ref_latency = result.ref_latency
+        kernel = blocksparse_matmul(M, N, K)
+        best_config = kernel.config
+        best_latency = kernel.latency
+        block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[
+            "block_K"]
        print(f"Best Config: {best_config}")
-        print(f"Block Dimensions (BM, BN, BK): ({block_M}, {block_N}, {block_K})")
        print(f"Sparsity Ratio: {sparsity}")
        print(f"Best Kernel Latency: {best_latency:.6f} ms")
-        print(f"Reference Latency: {ref_latency:.6f} ms")
    else:
        kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
                                    DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
......
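
With `get_best_config` removed, the example now drives tuning through the decorator; a minimal sketch of the new call pattern (names as in the diff above, the final kernel invocation is illustrative):

kernel = blocksparse_matmul(M, N, K)   # tuning runs when the decorated function is called
best_config = kernel.config            # dict keyed by block_M, block_N, block_K, ...
best_latency = kernel.latency          # latency of the best configuration in ms
# out = kernel(A, B, BlockMask)        # illustrative: invoke the tuned block-sparse kernel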
import torch
import argparse
import itertools
-import tilelang as tl
-from tilelang.autotuner import *
+import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
@@ -165,7 +164,7 @@ def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
        out_idx=[2],
        target="auto",
    ).set_profile_args(
-        supply_type=tl.TensorSupplyType.Integer,
+        supply_type=tilelang.TensorSupplyType.Integer,
        ref_prog=ref_prog,
        skip_check=False,
    )
@@ -299,9 +298,9 @@ def main(n: int = 128,
        kernel = result.kernel
    else:
        config = get_heuristic_config()
-        kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
-        profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
+        kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2])
+        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
        tilelang_latency = profiler.do_bench()
        ref_latency = profiler.do_bench(ref_prog)
        profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2)
......
import itertools
import logging
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import autotune
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
with_roller : bool
Whether to enable bitblas roller to deduce search spaces
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20
# Simple TIR Compute Expression
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = 0
config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
for config in configs:
print(config)
else:
block_M = [64]
block_N = [64]
block_K = [32]
num_stages = [0, 1]
thread_num = [128]
enable_rasterization = [False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
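# Worked example (illustrative, not part of the original file): for a roller hint with
# block = (128, 128), warp = (64, 64) and rstep = [32], the branch above produces
#   block_M = 128, block_N = 128, block_K = 32
#   thread_num = (128 * 128) // (64 * 64) * 32 = 4 * 32 = 128
# whereas the non-roller branch enumerates the small grid, yielding
# 1 * 1 * 1 * 2 * 1 * 1 = 2 configs (one per num_stages value).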
def matmul(M, N, K, with_roller):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
Returns
-------
(best_latency, best_config, ref_latency)
best_latency : float
The best latency found among the tuned configurations.
best_config : dict
The parameter configuration that yielded best_latency.
ref_latency : float
The baseline latency of the reference program (for computing speedup).
"""
@autotune(
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
)
@tilelang.jit(out_idx=[-1],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
enable_rasteration : bool
Whether to enable rasterization (swizzling) optimization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return kernel()
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=True)
get_configs(8192, 8192, 8192, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=True)
matmul(8192, 8192, 8192, with_roller=False)
if __name__ == "__main__":
tilelang.testing.main()
import itertools
import logging
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).
Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs():
iter_params = dict(
block_M=[64],
block_N=[64],
block_K=[32],
num_stages=[0, 1],
thread_num=[128],
enable_rasteration=[False])
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
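# For illustration (not part of the original file): the comprehension above expands the
# cartesian product of iter_params into 2 dicts (one per num_stages value), e.g.
#   {"block_M": 64, "block_N": 64, "block_K": 32,
#    "num_stages": 0, "thread_num": 128, "enable_rasteration": False}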
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M=128,
block_N=128,
block_K=32,
num_stages=0,
thread_num=128,
enable_rasteration=False):
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable (or disable) swizzling optimization
T.use_swizzle(panel_size=10, enable=enable_rasteration)
# Clear out the accumulation buffer
T.clear(C_local)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(
A[by * block_M, k * block_K],
A_shared,
)
# Load a sub-block of B from global memory into B_shared
T.copy(
B[bx * block_N, k * block_K],
B_shared,
)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_autotune(M: int, N: int, K: int):
import torch
a = torch.randn(M, K, dtype=torch.float16).cuda()
b = torch.randn(N, K, dtype=torch.float16).cuda()
with set_autotune_inputs([a, b]):
kernel = matmul(M, N, K)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def test_autotune_matmul():
run_autotune(8192, 8192, 8192)
if __name__ == "__main__":
tilelang.testing.main()
"""The auto-tune module for tilelang programs. from .tuner import (
autotune, # noqa: F401
This module provides functionality for auto-tuning tilelang programs, including JIT compilation AutoTuner, # noqa: F401
and performance optimization through configuration search. )
""" from .capture import (
set_autotune_inputs, # noqa: F401
import tilelang get_autotune_inputs, # noqa: F401
from tilelang import tvm as tvm
from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
from tqdm import tqdm
import logging
import functools
import concurrent.futures
import torch
import os
import sys
import signal
import json
import hashlib
import threading
import traceback
from pathlib import Path
from tilelang.env import (
TILELANG_CACHE_DIR,
TILELANG_AUTO_TUNING_CPU_UTILITIES,
TILELANG_AUTO_TUNING_CPU_COUNTS,
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
is_cache_enabled,
) )
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.jit.param import _P, _RProg
from tilelang.version import __version__
class TimeoutException(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutException("Operation timed out")
def run_with_timeout(func, timeout, *args, **kwargs):
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
try:
result = func(*args, **kwargs)
except Exception as e:
raise e
finally:
signal.alarm(0)
return result
# Configure logging for the autotuner module
# TODO: Consider creating a common logger in utils
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# Lazy handler initialization flag
_logger_handlers_initialized = False
def _init_logger_handlers():
global _logger_handlers_initialized
if _logger_handlers_initialized:
return
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
_logger_handlers_initialized = True
def get_available_cpu_count() -> int:
"""Gets the number of CPU cores available to the current process.
"""
try:
cpu_count = len(os.sched_getaffinity(0))
except AttributeError:
cpu_count = os.cpu_count()
return cpu_count
class AutoTuner:
"""Auto-tuner for tilelang programs.
This class handles the auto-tuning process by testing different configurations
and finding the optimal parameters for program execution.
Args:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
compile_args = CompileArgs()
profile_args = ProfileArgs()
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
def __init__(self, fn: Callable, configs):
self.fn = fn
self.configs = configs
self.ref_latency_cache = None
self.jit_input_tensors = None
self.ref_input_tensors = None
self.jit_compile = None
@classmethod
def from_kernel(cls, kernel: Callable, configs):
"""Create an AutoTuner instance from a kernel function.
Args:
kernel: The kernel function to auto-tune.
configs: List of configurations to try.
Returns:
AutoTuner: A new AutoTuner instance.
"""
return cls(kernel, configs)
def set_compile_args(self,
out_idx: Union[List[int], int, None] = None,
target: Literal['auto', 'cuda', 'hip'] = 'auto',
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Set compilation arguments for the auto-tuner.
Args:
out_idx: List of output tensor indices.
target: Target platform.
execution_backend: Execution backend to use for kernel execution.
target_host: Target host for cross-compilation.
verbose: Whether to enable verbose output.
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
Returns:
AutoTuner: Self for method chaining.
"""
self.compile_args = CompileArgs(
out_idx=out_idx,
target=target,
execution_backend=execution_backend,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs)
return self
def set_profile_args(self,
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False):
"""Set profiling arguments for the auto-tuner.
Args:
supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
ref_prog: Reference program for validation.
supply_prog: Supply program for input tensors.
rtol: Relative tolerance for validation.
atol: Absolute tolerance for validation.
max_mismatched_ratio: Maximum allowed mismatch ratio.
skip_check: Whether to skip validation.
manual_check_prog: Manual check program for validation.
cache_input_tensors: Whether to cache input tensors.
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time per configuration.
Returns:
AutoTuner: Self for method chaining.
"""
self.profile_args = ProfileArgs(
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
warmup=warmup,
rep=rep,
timeout=timeout)
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# becomes ineffective. The custom supply program will be used instead.
if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
logger.warning("Ignoring `supply_type` passed to `set_profile_args` because "
"`supply_prog` is not None.")
return self
def set_kernel_parameters(self, parameters: Tuple[str, ...]):
# for cache key generation
self._kernel_parameters = parameters
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process.
"""
# extract parameters from the function signature
op_parameters = []
for _, default_value in parameters.items():
if default_value.default is not inspect.Parameter.empty:
op_parameters.append(default_value.default)
if self._kernel_parameters is not None:
op_parameters += self._kernel_parameters
func_source = inspect.getsource(self.fn)
key_data = {
"version": __version__,
"op_parameters": tuple(op_parameters),
"func_source": func_source,
"configs": self.configs,
"compile_args": hash(self.compile_args),
"profile_args": hash(self.profile_args),
}
# Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True)
return hashlib.sha256(key_string.encode()).hexdigest()
def _save_result_to_disk(self, key, result: AutotuneResult):
result.save_to_disk(self.cache_dir / key)
def _load_result_from_disk(self, key) -> AutotuneResult:
result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args)
return result
def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
"""Run the auto-tuning process.
Args:
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time per configuration.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
_init_logger_handlers()
sig = inspect.signature(self.fn)
parameters = sig.parameters
if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters)
with self._lock:
if is_cache_enabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
return self._memory_cache[key]
# Then check disk cache
result = self._load_result_from_disk(key)
if result is not None:
# Populate memory cache with disk result
self._memory_cache[key] = result
return result
best_latency: float = 1e8
best_config: Optional[Dict[str, Any]] = None
best_kernel: Optional[tilelang.JITKernel] = None
def _compile(**config_arg) -> tilelang.JITKernel:
compile_args = self.compile_args
return compile_args.compile_program(self.fn(**config_arg))
if self.jit_compile is None:
self.jit_compile = _compile
def target_fn(jit_kernel: tilelang.JITKernel):
# Unpack the context
profile_args = self.profile_args
supply_type = profile_args.supply_type
skip_check = profile_args.skip_check
manual_check_prog = profile_args.manual_check_prog
cache_input_tensors = profile_args.cache_input_tensors
ref_prog = profile_args.ref_prog
supply_prog = profile_args.supply_prog
rtol = profile_args.rtol
atol = profile_args.atol
max_mismatched_ratio = profile_args.max_mismatched_ratio
profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type)
# Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
if cache_input_tensors:
params = profiler._get_params(with_output=False)
if self.jit_input_tensors is None:
self.jit_input_tensors = jit_input_tensors_supply()
else:
# check if the cached tensors are compatible with the current configuration
assert len(params) == len(
self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
for p, c in zip(params, self.jit_input_tensors):
if not isinstance(c, torch.Tensor):
# skip non-tensor inputs checking
continue
# Check tensor compatibility using generator expression
def shape_equal(a, b):
return all(
a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
for a_dim, b_dim in zip(a.shape, b.shape))
if p.dtype != c.dtype or not shape_equal(p, c):
logger.warning(
"\nIncompatible input tensor properties detected between cached tensors and "
"tensors regenerated for the current configuration trial. "
"This can happen if different tuning configurations require different input shapes/dtypes "
"and input tensor caching is enabled.\n"
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:\n"
" `cache_input_tensors=False`\n"
"within your `.set_compile_args(...)` call.\n")
# otherwise, regenerate the input tensors for safety
self.jit_input_tensors = jit_input_tensors_supply()
break
else:
self.jit_input_tensors = jit_input_tensors_supply()
if (not skip_check) and (ref_prog is not None):
if manual_check_prog is not None:
profiler.manual_assert_close(
ref_prog,
input_tensors=self.jit_input_tensors,
manual_check_prog=manual_check_prog)
else:
profiler.assert_allclose(
ref_prog,
input_tensors=self.jit_input_tensors,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench(
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
config_args = []
for config in self.configs:
new_kwargs = {}
keys = config.keys()
for name, _ in parameters.items():
if name in config:
new_kwargs[name] = config[name]
unused_keys = set(keys) - set(new_kwargs.keys())
if len(unused_keys) > 0:
raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs)
if len(config_args) == 0:
raise ValueError("No configurations to tune, please check your `@autotune` decorator")
# check if the tunable arguments has been set.
# get the back config argument
top_config, *rest = config_args
if self._kernel_parameters is not None:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]
# Check whether any tunable argument was already provided, by comparing config keys with key_kwargs_tuple
if any(key in top_config for key, _ in key_kwargs_tuple):
logger.warning(
f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
)
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
# get the cpu count
available_cpu_count = get_available_cpu_count()
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count)
logger.info(
f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used"
)
else:
num_workers = max(1, int(available_cpu_count * cpu_utilizations))
logger.info(
f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used"
)
if max_cpu_count > 0 and num_workers > max_cpu_count:
logger.warning(
f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used, but the max CPU count is {max_cpu_count}, so we will use {max_cpu_count} CPUs"
)
num_workers = max_cpu_count
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = []
future_to_index = {}
def device_wrapper(func, device, **config_arg):
torch.cuda.set_device(device)
return func(**config_arg)
for i, config_arg in enumerate(config_args):
future = pool.submit(
functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
**config_arg,
)
futures.append(future)
future_to_index[future] = i
results_with_configs = []
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Compiling configurations"):
idx = future_to_index[future]
config = config_args[idx]
try:
result = future.result()
results_with_configs.append((result, config))
except Exception as e:
logger.debug(
f"Compilation failed for config {config} at index {idx} with error: {e}")
continue
ref_latency = None
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar:
jit_kernel, config = results_with_configs[i]
try:
# Cannot use ThreadPoolExecutor to enforce a timeout on target_fn execution
# Because tma init may behave strangely with one thread
# latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException:
logger.info(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
)
continue
except Exception:
logger.info(
f"An error occurred while testing config {config}, checkout autotuner.log for more details"
)
logger.debug(f"Error: {traceback.format_exc()}")
continue
if latency < best_latency:
best_latency = latency
best_config = config
best_kernel = jit_kernel
progress_bar.set_postfix({"best_latency": best_latency})
tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
pool.shutdown()
if best_kernel is None:
error_msg = ("Auto-tuning failed: No configuration successfully "
"compiled and passed benchmarking/validation.")
logger.error(error_msg)
raise RuntimeError(error_msg)
best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
)
autotuner_result = AutotuneResult(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
libcode=best_kernel.get_kernel_source(),
func=best_kernel.prim_func,
kernel=best_kernel)
if self.compile_args.execution_backend == "dlpack":
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result
return autotuner_result
def __call__(self) -> Any:
"""Make the AutoTuner callable, running the auto-tuning process.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
return self.run()
class _AutoTunerImplementation:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Union[Dict, Callable] = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
rtol: float = 1e-2
atol: float = 1e-2
max_mismatched_ratio: float = 0.01
skip_check: bool = False
manual_check_prog: Callable = None
cache_input_tensors: bool = False
def __init__(self,
configs: Union[Dict, Callable],
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False) -> None:
"""Initialize the AutoTunerImplementation.
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation
supply_prog: Custom function to provide input tensors
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation
max_mismatched_ratio: Allowed percentage of mismatched values
skip_check: Bypass validation against reference implementation
manual_check_prog: Custom validation function
cache_input_tensors: Reuse input tensors across trials
"""
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
return wrapper
def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only
configs: Union[Dict, Callable],
# profile arguments
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
# compile arguments
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False,
):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
This decorator can be used without arguments (e.g., `@tilelang.jit`):
Applies JIT compilation with default settings.
Tips:
- If you want to skip the auto-tuning process, you can override the tunable parameters by passing them explicitly in the call.
```python
if enable_autotune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
```
Parameters
----------
func_or_out_idx : Any, optional
If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
If using `@tilelang.jit` directly on a function, this argument is implicitly
the function to be decorated (and `out_idx` will be `None`).
configs : Dict or Callable
Configuration space to explore during auto-tuning.
warmup : int, optional
Number of warmup iterations before timing.
rep : int, optional
Number of repetitions for timing measurements.
timeout : int, optional
target : Union[str, Target], optional
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None.
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
Backend for kernel execution and argument passing. Defaults to "cython".
verbose : bool, optional
Enables verbose logging during compilation. Defaults to False.
pass_configs : Optional[Dict[str, Any]], optional
Configurations for TVM's pass context. Defaults to None.
debug_root_path : Optional[str], optional
Directory to save compiled kernel source for debugging. Defaults to None.
Returns
-------
Callable
Either a JIT-compiled wrapper around the input function, or a configured decorator
instance that can then be applied to a function.
"""
if callable(func):
# Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
# This is a placeholder for a real auto tuner implementation
raise ValueError(
"Use tilelang.autotune to decorate func without arguments is not supported yet.")
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _AutoTunerImplementation(
configs=configs,
warmup=warmup,
rep=rep,
timeout=timeout,
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return configured_decorator
import threading
from typing import List, Any, Optional
# Use thread local to store the stack
# This is to avoid the cross-thread interference
_local = threading.local()
class CaptureStack:
"""
A simple stack implementation for capturing items in a thread-local context.
Used to manage a stack of items (e.g., input tensors) for auto-tuning capture.
"""
def __init__(self):
# Initialize an empty list to use as the stack
self.stack = []
def push(self, item):
"""
Push an item onto the top of the stack.
Args:
item: The item to be pushed onto the stack.
"""
self.stack.append(item)
def pop(self):
"""
Pop and return the top item from the stack.
Returns:
The item at the top of the stack.
Raises:
IndexError: If the stack is empty.
"""
return self.stack.pop()
def top(self):
"""
Return the item at the top of the stack without removing it.
Returns:
The item at the top of the stack.
Raises:
IndexError: If the stack is empty.
"""
return self.stack[-1]
def size(self):
"""
Return the number of items in the stack.
Returns:
int: The size of the stack.
"""
return len(self.stack)
def __len__(self):
"""
Return the number of items in the stack (len operator support).
Returns:
int: The size of the stack.
"""
return len(self.stack)
def __bool__(self):
"""
Return True if the stack is not empty, False otherwise.
Returns:
bool: Whether the stack contains any items.
"""
return bool(self.stack)
def _get_current_stack() -> CaptureStack:
if not hasattr(_local, "capture_stack"):
_local.capture_stack = CaptureStack()
return _local.capture_stack
class AutotuneInputsCapture:
__slots__ = ("tensors")
def __init__(self, tensors: List[Any]):
self.tensors = tensors
def __enter__(self) -> None:
_get_current_stack().push(self)
def __exit__(self, exc_type, exc_val, exc_tb):
_get_current_stack().pop()
def set_autotune_inputs(*args) -> AutotuneInputsCapture:
"""Set input tensors for auto-tuning.
This function creates a context manager for capturing input tensors
during the auto-tuning process. It supports both:
set_autotune_inputs(a, b, c)
set_autotune_inputs([a, b, c])
Args:
*args: Either a single list/tuple of tensors, or multiple tensor arguments.
Returns:
AutotuneInputsCapture: A context manager for auto-tuning inputs.
"""
if len(args) == 1 and isinstance(args[0], (list, tuple)):
tensors = list(args[0])
else:
tensors = list(args)
return AutotuneInputsCapture(tensors)
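# Usage sketch (illustrative, not part of the original file):
#
#   with set_autotune_inputs(a, b):        # varargs form
#       kernel = my_autotuned_kernel(M, N, K)
#   with set_autotune_inputs([a, b]):      # single list/tuple form
#       kernel = my_autotuned_kernel(M, N, K)
#
# Inside the `with` block, AutoTuner.set_profile_args picks the tensors up via
# get_autotune_inputs() and uses them in place of any custom supply_prog.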
def get_autotune_inputs() -> Optional[List[Any]]:
"""
Get the current autotune inputs from the stack.
"""
stack = _get_current_stack()
return stack.top().tensors if stack else None
"""The auto-tune module for tilelang programs.
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
from tqdm import tqdm
import logging
import functools
import concurrent.futures
import torch
import os
import sys
import signal
import json
import hashlib
import threading
import traceback
from pathlib import Path
from tilelang.env import (
TILELANG_CACHE_DIR,
TILELANG_AUTO_TUNING_CPU_UTILITIES,
TILELANG_AUTO_TUNING_CPU_COUNTS,
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
is_cache_enabled,
)
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.jit.param import _P, _RProg
from tilelang.version import __version__
class TimeoutException(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutException("Operation timed out")
def run_with_timeout(func, timeout, *args, **kwargs):
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
try:
result = func(*args, **kwargs)
except Exception as e:
raise e
finally:
signal.alarm(0)
return result
# Configure logging for the autotuner module
# TODO: Consider creating a common logger in utils
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# Lazy handler initialization flag
_logger_handlers_initialized = False
def _init_logger_handlers():
global _logger_handlers_initialized
if _logger_handlers_initialized:
return
formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
_logger_handlers_initialized = True
def get_available_cpu_count() -> int:
"""Gets the number of CPU cores available to the current process.
"""
try:
cpu_count = len(os.sched_getaffinity(0))
except AttributeError:
cpu_count = os.cpu_count()
return cpu_count or 1
class AutoTuner:
"""Auto-tuner for tilelang programs.
This class handles the auto-tuning process by testing different configurations
and finding the optimal parameters for program execution.
Args:
fn: The function to be auto-tuned.
configs: List of configurations to try during auto-tuning.
"""
compile_args = CompileArgs()
profile_args = ProfileArgs()
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
def __init__(self, fn: Callable, configs):
self.fn = fn
self.configs = configs
self.ref_latency_cache = None
self.jit_input_tensors = None
self.ref_input_tensors = None
self.jit_compile = None
@classmethod
def from_kernel(cls, kernel: Callable, configs):
"""Create an AutoTuner instance from a kernel function.
Args:
kernel: The kernel function to auto-tune.
configs: List of configurations to try.
Returns:
AutoTuner: A new AutoTuner instance.
"""
return cls(kernel, configs)
def set_compile_args(self,
out_idx: Union[List[int], int, None] = None,
target: Literal['auto', 'cuda', 'hip'] = 'auto',
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Set compilation arguments for the auto-tuner.
Args:
out_idx: List of output tensor indices.
target: Target platform.
execution_backend: Execution backend to use for kernel execution.
target_host: Target host for cross-compilation.
verbose: Whether to enable verbose output.
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
Returns:
AutoTuner: Self for method chaining.
"""
self.compile_args = CompileArgs(
out_idx=out_idx,
target=target,
execution_backend=execution_backend,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs)
return self
def set_profile_args(self,
warmup: int = 25,
rep: int = 100,
timeout: int = 30,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False):
"""Set profiling arguments for the auto-tuner.
Args:
supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
ref_prog: Reference program for validation.
supply_prog: Supply program for input tensors.
rtol: Relative tolerance for validation.
atol: Absolute tolerance for validation.
max_mismatched_ratio: Maximum allowed mismatch ratio.
skip_check: Whether to skip validation.
manual_check_prog: Manual check program for validation.
cache_input_tensors: Whether to cache input tensors.
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time per configuration.
Returns:
AutoTuner: Self for method chaining.
"""
# If the program is under `with set_autotune_inputs` context,
# the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead.
if get_autotune_inputs() is not None:
if supply_prog is not None:
logger.warning(
"`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
)
supply_prog = lambda _: get_autotune_inputs()  # noqa: E731
self.profile_args = ProfileArgs(
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
warmup=warmup,
rep=rep,
timeout=timeout)
# If a custom `supply_prog` is provided, the profiler's `supply_type` setting
# becomes ineffective. The custom supply program will be used instead.
if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
logger.warning("Ignoring `supply_type` passed to `set_profile_args` because "
"`supply_prog` is not None.")
return self
def set_kernel_parameters(self, parameters: Tuple[str, ...]):
# for cache key generation
self._kernel_parameters = parameters
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process.
"""
# extract parameters from the function signature
op_parameters = []
for _, default_value in parameters.items():
if default_value.default is not inspect.Parameter.empty:
op_parameters.append(default_value.default)
if self._kernel_parameters is not None:
op_parameters += self._kernel_parameters
func_source = inspect.getsource(self.fn)
key_data = {
"version": __version__,
"op_parameters": tuple(op_parameters),
"func_source": func_source,
"configs": self.configs,
"compile_args": hash(self.compile_args),
"profile_args": hash(self.profile_args),
}
# Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True)
return hashlib.sha256(key_string.encode()).hexdigest()
def _save_result_to_disk(self, key, result: AutotuneResult):
result.save_to_disk(self.cache_dir / key)
def _load_result_from_disk(self, key) -> AutotuneResult:
result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args)
return result
def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
"""Run the auto-tuning process.
Args:
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time per configuration.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
_init_logger_handlers()
sig = inspect.signature(self.fn)
parameters = sig.parameters
if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters)
with self._lock:
if is_cache_enabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
return self._memory_cache[key]
# Then check disk cache
result = self._load_result_from_disk(key)
if result is not None:
# Populate memory cache with disk result
self._memory_cache[key] = result
return result
best_latency: float = 1e8
best_config: Optional[Dict[str, Any]] = None
best_kernel: Optional[tilelang.JITKernel] = None
def _compile(**config_arg) -> tilelang.JITKernel:
compile_args = self.compile_args
return compile_args.compile_program(self.fn(**config_arg))
if self.jit_compile is None:
self.jit_compile = _compile
def target_fn(jit_kernel: tilelang.JITKernel):
# Unpack the context
profile_args = self.profile_args
supply_type = profile_args.supply_type
skip_check = profile_args.skip_check
manual_check_prog = profile_args.manual_check_prog
cache_input_tensors = profile_args.cache_input_tensors
ref_prog = profile_args.ref_prog
supply_prog = profile_args.supply_prog
rtol = profile_args.rtol
atol = profile_args.atol
max_mismatched_ratio = profile_args.max_mismatched_ratio
profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type)
# Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
if cache_input_tensors:
params = profiler._get_params(with_output=False)
if self.jit_input_tensors is None:
self.jit_input_tensors = jit_input_tensors_supply()
else:
# check if the cached tensors are compatible with the current configuration
assert len(params) == len(
self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)"
for p, c in zip(params, self.jit_input_tensors):
if not isinstance(c, torch.Tensor):
# skip non-tensor inputs checking
continue
# Check tensor compatibility using generator expression
def shape_equal(a, b):
return all(
a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var)
for a_dim, b_dim in zip(a.shape, b.shape))
if p.dtype != c.dtype or not shape_equal(p, c):
logger.warning(
"\nIncompatible input tensor properties detected between cached tensors and "
"tensors regenerated for the current configuration trial. "
"This can happen if different tuning configurations require different input shapes/dtypes "
"and input tensor caching is enabled.\n"
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:\n"
" `cache_input_tensors=False`\n"
"within your `.set_compile_args(...)` call.\n")
# otherwise, regenerate the input tensors for safety
self.jit_input_tensors = jit_input_tensors_supply()
break
else:
self.jit_input_tensors = jit_input_tensors_supply()
if (not skip_check) and (ref_prog is not None):
if manual_check_prog is not None:
profiler.manual_assert_close(
ref_prog,
input_tensors=self.jit_input_tensors,
manual_check_prog=manual_check_prog)
else:
profiler.assert_allclose(
ref_prog,
input_tensors=self.jit_input_tensors,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench(
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
return latency, self.ref_latency_cache
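        # Map every candidate config onto the kernel's tunable parameters; config keys
        # that do not correspond to a parameter are reported as an error.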
config_args = []
for config in self.configs:
new_kwargs = {}
keys = config.keys()
for name, _ in parameters.items():
if name in config:
new_kwargs[name] = config[name]
unused_keys = set(keys) - set(new_kwargs.keys())
if len(unused_keys) > 0:
raise ValueError(f"Unused keys in config: {unused_keys}")
config_args.append(new_kwargs)
if len(config_args) == 0:
raise ValueError("No configurations to tune, please check your `@autotune` decorator")
        # Check whether the tunable arguments have already been provided by the caller.
        # Use the first config to determine the tunable parameter names.
top_config, *rest = config_args
if self._kernel_parameters is not None:
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]
            # If any tunable argument was supplied explicitly by the caller, skip the
            # tuning sweep and compile directly with the provided values.
            if any(key in top_config for key, _ in key_kwargs_tuple):
                logger.warning(
                    f"Tunable parameters {tunable_arguments} were provided explicitly; skipping auto-tuning and compiling directly with the given values."
                )
# compile the kernel with the provided parameters
jit_kernel = self.jit_compile()
autotuner_result = AutotuneResult(
libcode=jit_kernel.get_kernel_source(),
func=jit_kernel.prim_func,
kernel=jit_kernel)
self._memory_cache[key] = autotuner_result
return autotuner_result
        # Determine how many worker threads to use from the available CPU count and
        # the TILELANG_AUTO_TUNING_* settings.
available_cpu_count = get_available_cpu_count()
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count)
            logger.info(
                f"Auto-tuning requested {cpu_counts} CPUs, {available_cpu_count} CPUs available; {num_workers} CPUs will be used"
            )
else:
num_workers = max(1, int(available_cpu_count * cpu_utilizations))
            logger.info(
                f"Auto-tuning with CPU utilization {cpu_utilizations}, {available_cpu_count} CPUs available; {num_workers} CPUs will be used"
            )
if max_cpu_count > 0 and num_workers > max_cpu_count:
            logger.warning(
                f"Requested {num_workers} CPUs exceeds the configured maximum of {max_cpu_count}; capping the worker count at {max_cpu_count}"
            )
num_workers = max_cpu_count
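        # Compile all candidate configs in parallel; each task pins the caller's
        # current CUDA device before invoking the compiler.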
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = []
future_to_index = {}
def device_wrapper(func, device, **config_arg):
torch.cuda.set_device(device)
return func(**config_arg)
for i, config_arg in enumerate(config_args):
future = pool.submit(
functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
**config_arg,
)
futures.append(future)
future_to_index[future] = i
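        # Gather successfully compiled kernels with their configs; failed compilations
        # are logged at debug level and skipped.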
results_with_configs = []
for future in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
desc="Compiling configurations"):
idx = future_to_index[future]
config = config_args[idx]
try:
result = future.result()
results_with_configs.append((result, config))
except Exception as e:
logger.debug(
f"Compilation failed for config {config} at index {idx} with error: {e}")
continue
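        # Benchmark each compiled kernel (with a per-config timeout) and track the
        # best latency observed so far.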
ref_latency = None
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar:
jit_kernel, config = results_with_configs[i]
try:
                # We cannot use a ThreadPoolExecutor to enforce a timeout on target_fn
                # because TMA initialization may behave strangely when run from a single worker thread.
                # latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException:
                logger.info(
                    f"A timeout occurred while testing config {config}; check autotuner.log for more details"
                )
continue
except Exception:
                logger.info(
                    f"An error occurred while testing config {config}; check autotuner.log for more details"
                )
logger.debug(f"Error: {traceback.format_exc()}")
continue
if latency < best_latency:
best_latency = latency
best_config = config
best_kernel = jit_kernel
progress_bar.set_postfix({"best_latency": best_latency})
tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
pool.shutdown()
if best_kernel is None:
error_msg = ("Auto-tuning failed: No configuration successfully "
"compiled and passed benchmarking/validation.")
logger.error(error_msg)
raise RuntimeError(error_msg)
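        # Record the winning latency/config on the kernel and package everything into
        # an AutotuneResult.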
best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
)
autotuner_result = AutotuneResult(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
libcode=best_kernel.get_kernel_source(),
func=best_kernel.prim_func,
kernel=best_kernel)
if self.compile_args.execution_backend == "dlpack":
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result
return autotuner_result
def __call__(self) -> Any:
"""Make the AutoTuner callable, running the auto-tuning process.
Returns:
AutotuneResult: Results of the auto-tuning process.
"""
return self.run()
class _AutoTunerImplementation:
    # Default values for the tuning and validation parameters; __init__ overrides them
    # per instance. The @overload declarations on __call__ below exist to help type
    # checkers infer the wrapper's return type.
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Union[Dict, Callable] = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
rtol: float = 1e-2
atol: float = 1e-2
max_mismatched_ratio: float = 0.01
skip_check: bool = False
manual_check_prog: Callable = None
cache_input_tensors: bool = False
def __init__(self,
configs: Union[Dict, Callable],
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False) -> None:
"""Initialize the AutoTunerImplementation.
Args:
configs: Configuration space to explore during auto-tuning.
warmup: Number of warmup iterations before timing.
rep: Number of repetitions for timing measurements.
timeout: Maximum time (in seconds) allowed for each configuration.
supply_type: Strategy for generating input tensors (random/zeros/etc)
ref_prog: Reference implementation for validation
supply_prog: Custom function to provide input tensors
rtol: Relative tolerance for numerical validation
atol: Absolute tolerance for numerical validation
max_mismatched_ratio: Allowed percentage of mismatched values
skip_check: Bypass validation against reference implementation
manual_check_prog: Custom validation function
cache_input_tensors: Reuse input tensors across trials
"""
# Configuration and benchmarking parameters
self.configs = configs # Search space of tuning configurations
self.warmup = warmup # Warmup iterations for stable measurements
self.rep = rep # Measurement repetitions for statistics
self.timeout = timeout # Per-configuration timeout threshold
# Tensor handling and validation setup
self.supply_type = supply_type # Input tensor generation strategy
self.ref_prog = ref_prog # Ground truth implementation
self.supply_prog = supply_prog # Custom input data provider
self.rtol = rtol # Relative error tolerance
self.atol = atol # Absolute error tolerance
self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch
# Validation control flags
self.skip_check = skip_check # Bypass accuracy verification
self.manual_check_prog = manual_check_prog # Custom validation
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
...
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
...
# Actual implementation of __call__
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
warmup = self.warmup
rep = self.rep
timeout = self.timeout
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
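            # Tune at most once per distinct (args, kwargs) combination; later calls
            # reuse the kernel cached for that key.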
if key not in self._tuner_cache:
def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
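                # Retrieve the compile arguments from the decorated function (via the
                # special `__return_compile_arguments=True` call) so the AutoTuner can
                # reuse them.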
compile_arguments = fn(__return_compile_arguments=True)
autotuner = AutoTuner(
fn, configs=self.configs).set_profile_args(
supply_type=self.supply_type,
ref_prog=self.ref_prog,
supply_prog=self.supply_prog,
rtol=self.rtol,
atol=self.atol,
max_mismatched_ratio=self.max_mismatched_ratio,
skip_check=self.skip_check,
manual_check_prog=self.manual_check_prog,
cache_input_tensors=self.cache_input_tensors,
).set_compile_args(
out_idx=compile_arguments['out_idx'],
execution_backend=compile_arguments['execution_backend'],
target=compile_arguments['target'],
target_host=compile_arguments['target_host'],
verbose=compile_arguments['verbose'],
pass_configs=compile_arguments['pass_configs'],
)
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key)
autotuner.run = partial(autotuner.run, warmup, rep, timeout)
artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return self._tuner_cache[key]
return wrapper
def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
*, # Indicates subsequent arguments are keyword-only
configs: Union[Dict, Callable],
# profile arguments
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
        # validation and input-supply arguments (forwarded to set_profile_args)
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
ref_prog: Callable = None,
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = False,
):
"""
    Auto-tuning decorator for TileLang kernel functions.

    Use it as `@tilelang.autotune(configs=...)`, typically stacked on top of a
    `@tilelang.jit`-decorated kernel factory. On the first call with a given set of
    arguments, every configuration in the search space is compiled and benchmarked,
    and the best-performing kernel is returned; the result is cached and reused for
    subsequent calls with the same arguments.
Tips:
        - If you want to skip the auto-tuning process, you can override the tunable parameters by passing them explicitly when calling the function.
```python
if enable_autotune:
kernel = flashattn(batch, heads, seq_len, dim, is_causal)
else:
kernel = flashattn(
batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)
```
    Parameters
    ----------
    func : Callable or PrimFunc, optional
        Reserved for bare-decorator usage, which is not supported yet; always call the
        decorator with keyword arguments, e.g. `@tilelang.autotune(configs=...)`.
    configs : Dict or Callable
        Configuration space to explore during auto-tuning.
    warmup : int, optional
        Number of warmup iterations before timing. Defaults to 25.
    rep : int, optional
        Number of repetitions for timing measurements. Defaults to 100.
    timeout : int, optional
        Maximum time (in seconds) allowed for each configuration trial. Defaults to 100.
    supply_type : tilelang.TensorSupplyType, optional
        Strategy for generating input tensors. Defaults to TensorSupplyType.Auto.
    ref_prog : Callable, optional
        Reference implementation used to validate tuned kernels. Defaults to None.
    supply_prog : Callable, optional
        Custom function that supplies input tensors. Defaults to None.
    rtol : float, optional
        Relative tolerance for numerical validation. Defaults to 1e-2.
    atol : float, optional
        Absolute tolerance for numerical validation. Defaults to 1e-2.
    max_mismatched_ratio : float, optional
        Maximum allowed ratio of mismatched values during validation. Defaults to 0.01.
    skip_check : bool, optional
        Bypass validation against the reference implementation. Defaults to False.
    manual_check_prog : Callable, optional
        Custom validation function. Defaults to None.
    cache_input_tensors : bool, optional
        Reuse input tensors across tuning trials. Defaults to False.
Returns
-------
Callable
        A configured decorator that, when applied to a kernel function, returns a wrapper
        which runs auto-tuning on the first call and caches the best compiled kernel.
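
    Examples
    --------
    A minimal sketch of the intended usage (the `my_configs` helper, the kernel
    factory, and `build_matmul_prim_func` are hypothetical placeholders, not part of
    this module; the decorator is assumed to sit on top of a `@tilelang.jit`-decorated
    factory as in the repository examples):

    ```python
    import tilelang

    def my_configs():
        # Hypothetical search space: one dict of tunable parameters per candidate.
        return [
            {"block_M": 64, "block_N": 64, "threads": 128},
            {"block_M": 128, "block_N": 128, "threads": 256},
        ]

    @tilelang.autotune(configs=my_configs())
    @tilelang.jit(out_idx=[-1])
    def matmul(M, N, K, block_M=None, block_N=None, threads=None):
        # build_matmul_prim_func is a hypothetical helper that constructs the
        # TileLang program for the given problem size and tunable parameters.
        return build_matmul_prim_func(M, N, K, block_M, block_N, threads)

    # The first call tunes over my_configs() and caches the best kernel;
    # later calls with the same arguments reuse the cached kernel.
    kernel = matmul(1024, 1024, 1024)
    ```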
"""
if callable(func):
        # Case 1: Used as bare @autotune directly on a function. This is not supported,
        # because the configuration space must be provided explicitly via `configs`.
        raise ValueError(
            "Using tilelang.autotune to decorate a function without arguments is not supported yet.")
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
        # Case 2: Used as @autotune(...) with configuration arguments.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _AutoTunerImplementation(
configs=configs,
warmup=warmup,
rep=rep,
timeout=timeout,
supply_type=supply_type,
ref_prog=ref_prog,
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
)
return configured_decorator