Commit 7171aff6 authored by Lei Wang, committed by LeiWang1999

[Autotune] Introduce cache mechanism for auto tuner (#527)

* [Enhancement] Add commit ID to versioning and improve logging initialization

* Updated `get_tilelang_version` to include an optional commit ID in the version string.
* Enhanced the `TileLangBuilPydCommand` to write the version, including the commit ID, to the VERSION file during the build process.
* Introduced a new function `get_git_commit_id` in `version.py` to retrieve the current git commit hash (a hedged sketch of such a helper follows this list).
* Refactored logger initialization in `autotuner/__init__.py` to ensure handlers are set up only once, improving performance and clarity.
* Minor fixes in `flatten_buffer.cc` and `kernel_cache.py` for better handling of versioning and logging.
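The exact implementation of `get_git_commit_id` is not shown here; the following is only a minimal sketch of how such a helper is commonly written, assuming it shells out to `git rev-parse` and returns `None` when git is unavailable. The real `tilelang/version.py` code may differ.

```python
# Hypothetical sketch only; the actual helper in version.py may differ.
import subprocess
from typing import Optional


def get_git_commit_id() -> Optional[str]:
    """Return the current git commit hash, or None if it cannot be determined."""
    try:
        out = subprocess.check_output(["git", "rev-parse", "HEAD"],
                                      stderr=subprocess.DEVNULL)
        return out.decode().strip()
    except (OSError, subprocess.CalledProcessError):
        return None
```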

* [Refactor] Enhance AutoTuner and JITKernel for improved performance and caching

* Refactored the AutoTuner class to include new methods for setting compilation and profiling arguments, enhancing configurability.
* Introduced caching mechanisms for tuning results, allowing for faster retrieval of previously computed configurations.
* Updated JITKernel to store tuning results, including latency and configuration details, improving the kernel's performance tracking.
* Added new methods for generating cache keys and saving/loading results to/from disk, streamlining the tuning process (the cache-key derivation is sketched after this list).
* Enhanced the overall structure and readability of the autotuning logic, ensuring better maintainability and clarity.
* Minor adjustments in related modules to support the new caching and profiling features.
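For reference, the cache key is derived roughly as follows. This is a standalone restatement of the new `AutoTuner.generate_cache_key` added in this commit; the `"0.0.0"` version string and the `*_hash` arguments are placeholders for `tilelang.__version__` and `hash(CompileArgs(...))` / `hash(ProfileArgs(...))`.

```python
# Standalone restatement of the new cache-key derivation (illustrative form).
import hashlib
import inspect
import json


def generate_cache_key(fn, configs, compile_args_hash: int, profile_args_hash: int) -> str:
    key_data = {
        "version": "0.0.0",                    # real code uses the (commit-aware) package version
        "func_source": inspect.getsource(fn),  # source of the kernel factory being tuned
        "configs": configs,                    # the configuration search space
        "compile_args": compile_args_hash,
        "profile_args": profile_args_hash,
    }
    # Keys are sorted so that semantically identical inputs always produce the same key.
    key_string = json.dumps(key_data, sort_keys=True)
    return hashlib.sha256(key_string.encode()).hexdigest()
```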

* [Refactor] Clean up code formatting and improve readability in AutoTuner and related modules

* Consolidated import statements and removed unnecessary line breaks for better readability.
* Standardized function argument formatting across the AutoTuner and CompileArgs classes.
* Enhanced consistency in the use of whitespace and indentation throughout the codebase.
* Minor adjustments in the Profiler and JITKernel classes to improve clarity and maintainability.
* Ensured that all changes adhere to the project's coding style guidelines.

* [Refactor] Remove redundant type hints in AutoTuner modules

* Simplified import statements in `__init__.py` and `param.py` by removing unnecessary duplicate type hints for `Any`.
* Improved code readability and maintainability by streamlining type imports across the AutoTuner module.

* [Refactor] Update AutoTuner configuration for improved profiling and target detection

* Enhanced the AutoTuner configuration across multiple examples by adding `set_profile_args` to better manage profiling settings (a usage sketch follows this list).
* Standardized the use of `target="auto"` in compile arguments to ensure automatic target detection.
* Removed redundant target specifications in certain instances to streamline the configuration process.
* Improved overall clarity and maintainability of the autotuning logic in various example scripts.
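A hedged usage sketch of the resulting call pattern is below; `my_matmul` and its `block_M` parameter are hypothetical placeholders for a real TileLang kernel factory, and only the chaining of `set_compile_args` / `set_profile_args` reflects the API introduced in this commit.

```python
from tilelang.autotuner import AutoTuner


def my_matmul(block_M: int = 64):
    """Placeholder: a real factory would build and return a TileLang program using block_M."""
    raise NotImplementedError


configs = [{"block_M": 64}, {"block_M": 128}]

tuner = (
    AutoTuner.from_kernel(my_matmul, configs)
    .set_compile_args(out_idx=-1, target="auto")       # "auto" enables automatic target detection
    .set_profile_args(warmup=25, rep=100, timeout=30)  # profiling knobs now live here
)
result = tuner.run()  # results are cached in memory and on disk when caching is enabled
print(result.latency, result.config)
```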

* [Refactor] Simplify code formatting and improve readability in example scripts

* Consolidated function argument formatting in `benchmark_mla_decode_amd_tilelang.py`, `example_elementwise_add.py`, and `performance.py` for better clarity.
* Removed unnecessary line breaks and standardized argument placement across multiple files.
* Enhanced overall code readability and maintainability in autotuning examples and performance scripts.

* [Refactor] Update JIT decorator usage across multiple files

* Removed redundant parameters from the JIT decorator in various benchmark and example scripts, simplifying the code.
* Standardized the import of the JIT decorator from `tilelang`, enhancing consistency across the codebase.
* Improved overall readability and maintainability by consolidating import statements and cleaning up function definitions.

* [Refactor] Standardize JIT decorator formatting across benchmark and example scripts

* Simplified the formatting of the JIT decorator in multiple files by removing unnecessary line breaks.
* Enhanced code readability and consistency in the usage of the JIT decorator across benchmark and example scripts.
* Improved overall maintainability by ensuring uniformity in function definitions and decorator usage.
parent 09581e4e
@@ -93,10 +93,10 @@ from .layout import (
 )
 from . import (
     transform,  # noqa: F401
-    autotuner,  # noqa: F401
     language,  # noqa: F401
     engine,  # noqa: F401
 )
+from .autotuner import autotune  # noqa: F401
 from .transform import PassConfigKey  # noqa: F401
 from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
......
@@ -6,18 +6,28 @@ and performance optimization through configuration search.
 import tilelang
 from tilelang import tvm as tvm
+from tvm.tir import PrimFunc
+from tvm.target import Target
 import inspect
-from functools import wraps, partial
-from typing import Callable, List, Literal, Any, Optional, Union
+from functools import partial
+from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
 from tqdm import tqdm
 import logging
 import functools
-from dataclasses import dataclass
 import concurrent.futures
 import torch
 import os
 import sys
 import signal
+import json
+import hashlib
+import threading
+from pathlib import Path
+from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
+from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
+from tilelang.jit.param import _P, _RProg
+from tilelang.version import __version__


 class TimeoutException(Exception):
@@ -64,90 +74,15 @@ def _init_logger_handlers():
     _logger_handlers_initialized = True
-@dataclass(frozen=True)
-class JITContext:
-    """Context object for Just-In-Time compilation settings.
-    Attributes:
-        out_idx: List of output tensor indices.
-        ref_prog: Reference program for correctness validation.
-        supply_prog: Supply program for input tensors.
-        rtol: Relative tolerance for output validation.
-        atol: Absolute tolerance for output validation.
-        max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
-        skip_check: Whether to skip validation checks.
-        cache_input_tensors: Whether to cache input tensors for each compilation.
-        kernel: JITKernel instance for performance measurement.
-        supply_type: Type of tensor supply mechanism.
-        target: Target platform ('cuda' or 'hip').
-    """
-    out_idx: List[int]
-    ref_prog: Callable
-    supply_prog: Callable
-    rtol: float
-    atol: float
-    max_mismatched_ratio: float
-    skip_check: bool
-    manual_check_prog: Callable
-    cache_input_tensors: bool
-    kernel: tilelang.JITKernel
-    supply_type: tilelang.TensorSupplyType
-    target: Literal['cuda', 'hip']
-
-
-@dataclass(frozen=True)
-class AutotuneResult:
-    """Results from auto-tuning process.
-    Attributes:
-        latency: Best achieved execution latency.
-        config: Configuration that produced the best result.
-        ref_latency: Reference implementation latency.
-        libcode: Generated library code.
-        func: Optimized function.
-        kernel: Compiled kernel function.
-    """
-    latency: float
-    config: dict
-    ref_latency: float
-    libcode: str
-    func: Callable
-    kernel: Callable
-
-
-@dataclass(frozen=True)
-class CompileArgs:
-    """Compile arguments for the auto-tuner.
-    Attributes:
-        out_idx: List of output tensor indices.
-        supply_type: Type of tensor supply mechanism.
-        ref_prog: Reference program for correctness validation.
-        supply_prog: Supply program for input tensors.
-        out_idx: Union[List[int], int] = -1
-        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
-        ref_prog: Callable = None
-        supply_prog: Callable = None
-        rtol: float = 1e-2
-        atol: float = 1e-2
-        max_mismatched_ratio: float = 0.01
-        skip_check: bool = False
-        manual_check_prog: Callable = None
-        cache_input_tensors: bool = True
-        target: Literal['auto', 'cuda', 'hip'] = 'auto'
-    """
-    out_idx: Union[List[int], int] = -1
-    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
-    ref_prog: Callable = None
-    supply_prog: Callable = None
-    rtol: float = 1e-2
-    atol: float = 1e-2
-    max_mismatched_ratio: float = 0.01
-    skip_check: bool = False
-    manual_check_prog: Callable = None
-    cache_input_tensors: bool = True
-    target: Literal['auto', 'cuda', 'hip'] = 'auto'
+def get_available_cpu_count() -> int:
+    """Gets the number of CPU cores available to the current process.
+    """
+    try:
+        cpu_count = len(os.sched_getaffinity(0))
+    except AttributeError:
+        cpu_count = os.cpu_count()
+
+    return cpu_count


 class AutoTuner:
@@ -160,6 +95,12 @@ class AutoTuner:
         fn: The function to be auto-tuned.
         configs: List of configurations to try during auto-tuning.
     """
+    compile_args = CompileArgs()
+    profile_args = ProfileArgs()
+
+    _lock = threading.Lock()  # For thread safety
+    _memory_cache = {}  # In-memory cache dictionary
+    cache_dir: Path = Path(TILELANG_CACHE_DIR)

     def __init__(self, fn: Callable, configs):
         self.fn = fn
@@ -168,7 +109,6 @@ class AutoTuner:
         self.jit_input_tensors = None
         self.ref_input_tensors = None
         self.jit_compile = None
-        self.compile_args = CompileArgs()

     @classmethod
     def from_kernel(cls, kernel: Callable, configs):
@@ -185,6 +125,38 @@
     def set_compile_args(self,
                          out_idx: Union[List[int], int, None] = None,
+                         target: Literal['auto', 'cuda', 'hip'] = 'auto',
+                         execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
+                         target_host: Union[str, Target] = None,
+                         verbose: bool = False,
+                         pass_configs: Optional[Dict[str, Any]] = None):
+        """Set compilation arguments for the auto-tuner.
+
+        Args:
+            out_idx: List of output tensor indices.
+            target: Target platform.
+            execution_backend: Execution backend to use for kernel execution.
+            target_host: Target host for cross-compilation.
+            verbose: Whether to enable verbose output.
+            pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
+
+        Returns:
+            AutoTuner: Self for method chaining.
+        """
+        self.compile_args = CompileArgs(
+            out_idx=out_idx,
+            target=target,
+            execution_backend=execution_backend,
+            target_host=target_host,
+            verbose=verbose,
+            pass_configs=pass_configs)
+        return self
+
+    def set_profile_args(self,
+                         warmup: int = 25,
+                         rep: int = 100,
+                         timeout: int = 30,
                          supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
                          ref_prog: Callable = None,
                          supply_prog: Callable = None,
@@ -193,12 +165,10 @@ class AutoTuner:
                          max_mismatched_ratio: float = 0.01,
                          skip_check: bool = False,
                          manual_check_prog: Callable = None,
-                         cache_input_tensors: bool = True,
-                         target: Literal['auto', 'cuda', 'hip'] = 'auto'):
-        """Set compilation arguments for the auto-tuner.
+                         cache_input_tensors: bool = True):
+        """Set profiling arguments for the auto-tuner.

         Args:
-            out_idx: List of output tensor indices.
             supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
             ref_prog: Reference program for validation.
             supply_prog: Supply program for input tensors.
@@ -208,13 +178,14 @@ class AutoTuner:
             skip_check: Whether to skip validation.
             manual_check_prog: Manual check program for validation.
             cache_input_tensors: Whether to cache input tensors.
-            target: Target platform.
+            warmup: Number of warmup iterations.
+            rep: Number of repetitions for timing.
+            timeout: Maximum time per configuration.

         Returns:
            AutoTuner: Self for method chaining.
         """
-        self.compile_args = CompileArgs(
-            out_idx=out_idx,
+        self.profile_args = ProfileArgs(
             supply_type=supply_type,
             ref_prog=ref_prog,
             supply_prog=supply_prog,
@@ -224,16 +195,40 @@ class AutoTuner:
             skip_check=skip_check,
             manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
-            target=target)
+            warmup=warmup,
+            rep=rep,
+            timeout=timeout)

-        # If a custom `supply_prog`` is provided, the profiler's `supply_type` setting
+        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
         # becomes ineffective. The custom supply program will be used instead.
-        if ref_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
-            logger.warning("Ignoring `supply_type` passed to `set_compile_args` because "
-                           "`ref_prog` is not None.")
+        if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
+            logger.warning("Ignoring `supply_type` passed to `set_profile_args` because "
+                           "`supply_prog` is not None.")
         return self

+    def generate_cache_key(self) -> Optional[AutotuneResult]:
+        """Generate a cache key for the auto-tuning process.
+        """
+        func_source = inspect.getsource(self.fn)
+        key_data = {
+            "version": __version__,
+            "func_source": func_source,
+            "configs": self.configs,
+            "compile_args": hash(self.compile_args),
+            "profile_args": hash(self.profile_args),
+        }
+        # Sort keys to ensure consistency
+        key_string = json.dumps(key_data, sort_keys=True)
+        return hashlib.sha256(key_string.encode()).hexdigest()
+
+    def _save_result_to_disk(self, key, result: AutotuneResult):
+        result.save_to_disk(self.cache_dir / key)
+
+    def _load_result_from_disk(self, key) -> AutotuneResult:
+        result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args)
+        return result
+
     def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
         """Run the auto-tuning process.
@@ -246,50 +241,50 @@ class AutoTuner:
             AutotuneResult: Results of the auto-tuning process.
         """
         _init_logger_handlers()
+        key = self.generate_cache_key()
+        with self._lock:
+            if is_cache_enabled():
+                # First check in-memory cache
+                if key in self._memory_cache:
+                    self.logger.warning("Found kernel in memory cache. For better performance," \
+                        " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
+                    return self._memory_cache[key]
+
+                # Then check disk cache
+                result = self._load_result_from_disk(key)
+                if result is not None:
+                    # Populate memory cache with disk result
+                    self._memory_cache[key] = result
+                    return result
+
         sig = inspect.signature(self.fn)
keys = list(sig.parameters.keys()) parameters = sig.parameters
bound_args = sig.bind() best_latency: float = 1e8
bound_args.apply_defaults() best_config: Optional[Dict[str, Any]] = None
best_latency = 1e8 best_kernel: Optional[tilelang.JITKernel] = None
best_config = None
best_jit_context = None def _compile(**config_arg) -> tilelang.JITKernel:
def _compile(*config_arg):
compile_args = self.compile_args compile_args = self.compile_args
kernel = tilelang.compile( return compile_args.compile_program(self.fn(**config_arg))
self.fn(*config_arg), out_idx=compile_args.out_idx, target=compile_args.target)
jit_context = JITContext(
out_idx=compile_args.out_idx,
ref_prog=compile_args.ref_prog,
supply_prog=compile_args.supply_prog,
rtol=compile_args.rtol,
atol=compile_args.atol,
max_mismatched_ratio=compile_args.max_mismatched_ratio,
skip_check=compile_args.skip_check,
manual_check_prog=compile_args.manual_check_prog,
cache_input_tensors=compile_args.cache_input_tensors,
kernel=kernel,
supply_type=compile_args.supply_type,
target=compile_args.target)
return jit_context
if self.jit_compile is None: if self.jit_compile is None:
self.jit_compile = _compile self.jit_compile = _compile
def target_fn(jit_context: JITContext): def target_fn(jit_kernel: tilelang.JITKernel):
# Unpack the context # Unpack the context
kernel = jit_context.kernel profile_args = self.profile_args
supply_type = jit_context.supply_type supply_type = profile_args.supply_type
skip_check = jit_context.skip_check skip_check = profile_args.skip_check
manual_check_prog = jit_context.manual_check_prog manual_check_prog = profile_args.manual_check_prog
cache_input_tensors = jit_context.cache_input_tensors cache_input_tensors = profile_args.cache_input_tensors
ref_prog = jit_context.ref_prog ref_prog = profile_args.ref_prog
supply_prog = jit_context.supply_prog supply_prog = profile_args.supply_prog
rtol = jit_context.rtol rtol = profile_args.rtol
atol = jit_context.atol atol = profile_args.atol
max_mismatched_ratio = jit_context.max_mismatched_ratio max_mismatched_ratio = profile_args.max_mismatched_ratio
profiler = kernel.get_profiler(tensor_supply_type=supply_type) profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type)
# Factory functions for generating input tensors. # Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`) # This encapsulates the logic of using either a custom supply program (`supply_prog`)
...@@ -308,21 +303,18 @@ class AutoTuner: ...@@ -308,21 +303,18 @@ class AutoTuner:
ref_input_tensors_supply = get_input_tensors_supply(with_output=False) ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
if cache_input_tensors: if cache_input_tensors:
jit_input_tensors = jit_input_tensors_supply() if supply_prog is not None:
if self.jit_input_tensors is not None: logger.warning(
if not check_tensor_list_compatibility(self.jit_input_tensors, "Incompatible input tensor properties detected between cached tensors and "
jit_input_tensors): "tensors regenerated for the current configuration trial. "
logger.warning( "This can happen if different tuning configurations require different input shapes/dtypes "
"Incompatible input tensor properties detected between cached tensors and " "and input tensor caching is enabled.\n"
"tensors regenerated for the current configuration trial. " "To ensure fresh, compatible inputs are generated for every trial "
"This can happen if different tuning configurations require different input shapes/dtypes " "you can disable caching by setting:\n"
"and input tensor caching is enabled.\n" " `cache_input_tensors=False`\n"
"To ensure fresh, compatible inputs are generated for every trial " "within your `.set_compile_args(...)` call.\n")
"you can disable caching by setting:\n" self.jit_input_tensors = jit_input_tensors_supply(
" `cache_input_tensors=False`\n" ) if self.jit_input_tensors is None else self.jit_input_tensors
"within your `.set_compile_args(...)` call.\n")
self.jit_input_tensors = jit_input_tensors
self.jit_input_tensors = jit_input_tensors
else: else:
self.jit_input_tensors = jit_input_tensors_supply() self.jit_input_tensors = jit_input_tensors_supply()
...@@ -350,30 +342,25 @@ class AutoTuner: ...@@ -350,30 +342,25 @@ class AutoTuner:
config_args = [] config_args = []
for config in self.configs: for config in self.configs:
new_args = [] new_kwargs = {}
for name, value in bound_args.arguments.items(): for name, _ in parameters.items():
if name not in keys: if name in config:
new_args.append(value) new_kwargs[name] = config[name]
else: config_args.append(new_kwargs)
if name not in config:
raise ValueError(f"Configuration {config} does not contain key {name}")
new_args.append(config[name])
new_args = tuple(new_args)
config_args.append(new_args)
num_workers = max(1, int(get_available_cpu_count() * 0.9)) num_workers = max(1, int(get_available_cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
futures = [] futures = []
future_to_index = {} future_to_index = {}
def device_wrapper(func, device, *config_arg): def device_wrapper(func, device, **config_arg):
torch.cuda.set_device(device) torch.cuda.set_device(device)
return func(*config_arg) return func(**config_arg)
for i, config_arg in enumerate(config_args): for i, config_arg in enumerate(config_args):
future = pool.submit( future = pool.submit(
functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()), functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
*config_arg, **config_arg,
) )
futures.append(future) futures.append(future)
future_to_index[future] = i future_to_index[future] = i
...@@ -396,12 +383,12 @@ class AutoTuner: ...@@ -396,12 +383,12 @@ class AutoTuner:
ref_latency = None ref_latency = None
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations") progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar: for i in progress_bar:
jit_context, config = results_with_configs[i] jit_kernel, config = results_with_configs[i]
try: try:
# Cannot ThreadPoolExecutor to enforce timeout on target_fn execution # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
# Because tma init may behave strangely with one thread # Because tma init may behave strangely with one thread
# latency, ref_latency = target_fn(jit_context) # latency, ref_latency = target_fn(jit_kernel)
latency, ref_latency = run_with_timeout(target_fn, timeout, jit_context) latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
except TimeoutException: except TimeoutException:
logger.info( logger.info(
f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
...@@ -419,26 +406,43 @@ class AutoTuner: ...@@ -419,26 +406,43 @@ class AutoTuner:
if latency < best_latency: if latency < best_latency:
best_latency = latency best_latency = latency
best_config = config best_config = config
best_jit_context = jit_context best_kernel = jit_kernel
progress_bar.set_postfix({"best_latency": best_latency}) progress_bar.set_postfix({"best_latency": best_latency})
tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}") tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
pool.shutdown() pool.shutdown()
if best_jit_context is None: if best_kernel is None:
error_msg = ("Auto-tuning failed: No configuration successfully " error_msg = ("Auto-tuning failed: No configuration successfully "
"compiled and passed benchmarking/validation.") "compiled and passed benchmarking/validation.")
logger.error(error_msg) logger.error(error_msg)
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
return AutotuneResult( best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result(
latency=best_latency,
config=best_config,
ref_latency=ref_latency,
)
autotuner_result = AutotuneResult(
latency=best_latency, latency=best_latency,
config=best_config, config=best_config,
ref_latency=ref_latency, ref_latency=ref_latency,
libcode=best_jit_context.kernel.get_kernel_source(), libcode=best_kernel.get_kernel_source(),
func=self.fn(*best_config), func=best_kernel.prim_func,
kernel=best_jit_context.kernel) kernel=best_kernel)
if self.compile_args.execution_backend == "dlpack":
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result
return autotuner_result
def __call__(self) -> Any: def __call__(self) -> Any:
"""Make the AutoTuner callable, running the auto-tuning process. """Make the AutoTuner callable, running the auto-tuning process.
...@@ -449,116 +453,121 @@ class AutoTuner: ...@@ -449,116 +453,121 @@ class AutoTuner:
return self.run() return self.run()
def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> AutotuneResult: class _AutoTunerImplementation:
"""Decorator for auto-tuning tilelang programs. # Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
Args: warmup: int = 25
configs: Configuration space to explore during auto-tuning. rep: int = 100
warmup: Number of warmup iterations before timing. timeout: int = 100
rep: Number of repetitions for timing measurements. configs: Any = None
timeout: Maximum time (in seconds) allowed for each configuration.
Returns: def __init__(self, configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> None:
Callable: Decorated function that performs auto-tuning. """Initialize the AutoTunerImplementation.
"""
def decorator(fn: Callable) -> AutoTuner: Args:
autotuner = AutoTuner(fn, configs=configs) configs: Configuration space to explore during auto-tuning.
autotuner.jit_compile = fn warmup: Number of warmup iterations before timing.
autotuner.run = partial(autotuner.run, warmup, rep, timeout) rep: Number of repetitions for timing measurements.
return autotuner timeout: Maximum time (in seconds) allowed for each configuration.
"""
self.configs = configs
self.warmup = warmup
self.rep = rep
self.timeout = timeout
return decorator self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
...
def jit(out_idx: Optional[List[int]] = None, @overload
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
ref_prog: Callable = None, ...
supply_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
max_mismatched_ratio: float = 0.01,
skip_check: bool = False,
manual_check_prog: Callable = None,
cache_input_tensors: bool = True,
target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
"""Just-In-Time compilation decorator for tilelang programs.
Args: # Actual implementation of __call__
out_idx: List of output tensor indices. def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided. warmup = self.warmup
ref_prog: Reference program for correctness validation. rep = self.rep
supply_prog: Supply program for input tensors. timeout = self.timeout
rtol: Relative tolerance for output validation. configs = self.configs
atol: Absolute tolerance for output validation.
max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
skip_check: Whether to skip validation checks.
manual_check_prog: Manual check program for validation.
cache_input_tensors: Whether to cache input tensors for each compilation.
target: Target platform ('auto', 'cuda', or 'hip').
Returns:
Callable: Decorated function that performs JIT compilation.
"""
# If a custom `supply_prog`` is provided, the profiler's `supply_type` setting @functools.wraps(fn)
# becomes ineffective. The custom supply program will be used instead. def wrapper(*args, **kwargs):
if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
logger.warning("Ignoring `supply_type` passed to `autotune.jit` because "
"`supply_prog` is not None.")
def wrapper(fn: Callable): key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
@wraps(fn) if key not in self._tuner_cache:
def decorator(*args, **kwargs) -> float:
kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target) def jit_compile(**config_arg):
return fn(*args, **kwargs, __tune_params=config_arg)
return JITContext( autotuner = AutoTuner(fn, configs=configs)
out_idx=out_idx, autotuner.jit_compile = jit_compile
ref_prog=ref_prog, autotuner.run = partial(autotuner.run, warmup, rep, timeout)
supply_prog=supply_prog,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
skip_check=skip_check,
manual_check_prog=manual_check_prog,
cache_input_tensors=cache_input_tensors,
kernel=kernel,
supply_type=supply_type,
target=target)
return decorator artifact = autotuner.run()
self._tuner_cache[key] = artifact.kernel
return wrapper return self._tuner_cache[key]
return wrapper
def check_tensor_list_compatibility(
list1: List[torch.Tensor],
list2: List[torch.Tensor],
) -> bool:
"""Checks if two lists of tensors are compatible.
Compatibility checks performed include:
1. Lists have the same length.
2. Corresponding tensors have the same shape.
Args: def autotune( # This is the new public interface
list1: First list of tensors. func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
list2: Second list of tensors. *, # Indicates subsequent arguments are keyword-only
configs: Any,
warmup: int = 25,
rep: int = 100,
timeout: int = 100):
""" """
if len(list1) != len(list2): Just-In-Time (JIT) compiler decorator for TileLang functions.
return False
This decorator can be used without arguments (e.g., `@tilelang.jit`):
return all(tensor1.shape == tensor2.shape for tensor1, tensor2 in zip(list1, list2)) Applies JIT compilation with default settings.
Parameters
def get_available_cpu_count(): ----------
"""Gets the number of CPU cores available to the current process. func_or_out_idx : Any, optional
If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter.
If using `@tilelang.jit` directly on a function, this argument is implicitly
the function to be decorated (and `out_idx` will be `None`).
target : Union[str, Target], optional
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None.
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
Backend for kernel execution and argument passing. Defaults to "cython".
verbose : bool, optional
Enables verbose logging during compilation. Defaults to False.
pass_configs : Optional[Dict[str, Any]], optional
Configurations for TVM's pass context. Defaults to None.
debug_root_path : Optional[str], optional
Directory to save compiled kernel source for debugging. Defaults to None.
Returns
-------
Callable
Either a JIT-compiled wrapper around the input function, or a configured decorator
instance that can then be applied to a function.
""" """
try: if callable(func):
cpu_count = len(os.sched_getaffinity(0)) # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults)
except AttributeError: # This is a placeholder for a real auto tuner implementation
cpu_count = os.cpu_count() raise ValueError(
"Use tilelang.autotune to decorate func without arguments is not supported yet.")
return cpu_count elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
else:
# Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _AutoTunerImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator = _AutoTunerImplementation(
configs=configs, warmup=warmup, rep=rep, timeout=timeout)
return configured_decorator
"""The auto-tune parameters.
"""
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from typing import Callable, List, Literal, Any, Optional, Union, Dict
from dataclasses import dataclass
from pathlib import Path
from tilelang.jit import JITKernel
import cloudpickle
import os
import shutil
from tilelang.engine.param import KernelParam
from tilelang import logger
import json
import hashlib
BEST_CONFIG_PATH = "best_config.json"
FUNCTION_PATH = "function.pkl"
LATENCY_PATH = "latency.json"
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
PARAMS_PATH = "params.pkl"
@dataclass(frozen=True)
class CompileArgs:
"""Compile arguments for the auto-tuner. Detailed description can be found in `tilelang.jit.compile`.
Attributes:
out_idx: List of output tensor indices.
execution_backend: Execution backend to use for kernel execution (default: "cython").
target: Compilation target, either as a string or a TVM Target object (default: "auto").
target_host: Target host for cross-compilation (default: None).
verbose: Whether to enable verbose output (default: False).
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
"""
out_idx: Union[List[int], int] = -1
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
target_host: Union[str, Target] = None
verbose: bool = False
pass_configs: Optional[Dict[str, Any]] = None
def compile_program(self, program: PrimFunc):
return tilelang.compile(
program,
out_idx=self.out_idx,
target=self.target,
target_host=self.target_host,
verbose=self.verbose,
pass_configs=self.pass_configs)
def __hash__(self):
data = {
"out_idx":
self.out_idx,
"execution_backend":
self.execution_backend,
"target":
self.target,
"target_host":
str(self.target_host) if self.target_host else None,
"verbose":
self.verbose,
"pass_configs":
json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
}
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
return int.from_bytes(hash_obj.digest(), byteorder='big')
@dataclass(frozen=True)
class ProfileArgs:
"""Profile arguments for the auto-tuner.
Attributes:
warmup: Number of warmup iterations.
rep: Number of repetitions for timing.
timeout: Maximum time per configuration.
supply_type: Type of tensor supply mechanism.
ref_prog: Reference program for correctness validation.
supply_prog: Supply program for input tensors.
out_idx: Union[List[int], int] = -1
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
rtol: float = 1e-2
atol: float = 1e-2
max_mismatched_ratio: float = 0.01
skip_check: bool = False
manual_check_prog: Callable = None
cache_input_tensors: bool = True
"""
warmup: int = 25
rep: int = 100
timeout: int = 30
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
rtol: float = 1e-2
atol: float = 1e-2
max_mismatched_ratio: float = 0.01
skip_check: bool = False
manual_check_prog: Callable = None
cache_input_tensors: bool = True
def __hash__(self):
data = {
"warmup": self.warmup,
"rep": self.rep,
"timeout": self.timeout,
"supply_type": str(self.supply_type),
"rtol": self.rtol,
"atol": self.atol,
"max_mismatched_ratio": self.max_mismatched_ratio,
}
hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
return int.from_bytes(hash_obj.digest(), byteorder='big')
@dataclass(frozen=True)
class AutotuneResult:
"""Results from auto-tuning process.
Attributes:
latency: Best achieved execution latency.
config: Configuration that produced the best result.
ref_latency: Reference implementation latency.
libcode: Generated library code.
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: float
config: dict
ref_latency: float
libcode: str
func: Callable
kernel: Callable
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel):
"""
Persists a compiled kernel to disk cache.
Args:
key (str): The hash key identifying the kernel.
kernel (JITKernel): The compiled kernel to be saved.
func (Callable, optional): The original function.
Note:
Saves the following files:
- kernel.cu: The compiled kernel source code
- wrapped_kernel.cu: The wrapped kernel source code
- kernel_lib.so: The compiled kernel library
- params.pkl: The serialized kernel parameters
"""
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
logger.error(f"Error saving kernel source code to disk: {e}")
# Save wrapped kernel source code
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "w") as f:
f.write(kernel.adapter.get_kernel_source())
except Exception as e:
logger.error(f"Error saving wrapped kernel source code to disk: {e}")
# Save kernel library
try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
shutil.copy(src_lib_path, kernel_lib_path)
except Exception as e:
logger.error(f"Error saving kernel library to disk: {e}")
# Save kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "wb") as f:
cloudpickle.dump(kernel.params, f)
except Exception as e:
logger.error(f"Error saving kernel parameters to disk: {e}")
def _load_kernel_from_disk(
self,
cache_path: Path,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
) -> JITKernel:
"""
Loads a previously compiled kernel from disk cache.
Args:
key (str): The hash key identifying the kernel.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
out_idx (List[int], optional): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
pass_configs (dict, optional): Configuration for compiler passes.
func (Callable, optional): The original function.
Returns:
JITKernel: The loaded kernel if found, None otherwise.
"""
if not os.path.exists(cache_path):
return None
kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "r") as f:
kernel_global_source = f.read()
except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
# Load kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "rb") as f:
kernel_params = cloudpickle.load(f)
except Exception as e:
logger.error(f"Error loading kernel parameters from disk: {e}")
if kernel_global_source and kernel_params:
return JITKernel.from_database(
func=func,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
params=kernel_params,
target=target,
target_host=target_host,
out_idx=out_idx,
execution_backend=execution_backend,
pass_configs=pass_configs,
)
else:
return None
def save_to_disk(self, path: Path):
if not os.path.exists(path):
os.makedirs(path)
# save best config
with open(path / BEST_CONFIG_PATH, "w") as f:
json.dump(self.config, f)
# save function
with open(path / FUNCTION_PATH, "wb") as f:
cloudpickle.dump(self.func, f)
# save ref latency
with open(path / LATENCY_PATH, "w") as f:
json.dump({
"latency": self.latency,
"ref_latency": self.ref_latency,
}, f)
# save kernel
self._save_kernel_to_disk(path, self.kernel)
@classmethod
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
if not os.path.exists(path):
return None
# load best config
with open(path / BEST_CONFIG_PATH, "r") as f:
config = json.load(f)
# load function
with open(path / FUNCTION_PATH, "rb") as f:
func = cloudpickle.load(f)
# load latency
with open(path / LATENCY_PATH, "r") as f:
latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"]
kernel = cls._load_kernel_from_disk(cls, path, compile_args.target,
compile_args.target_host, compile_args.out_idx,
compile_args.execution_backend,
compile_args.pass_configs, func)
if kernel is None:
return None
kernel.update_tuner_result(
config=config,
latency=latency,
ref_latency=ref_latency,
)
result = cls(
config=config,
func=func,
kernel=kernel,
libcode=kernel.get_kernel_source(),
latency=latency,
ref_latency=ref_latency,
)
return result
@@ -174,17 +174,8 @@ class KernelCache:
         if execution_backend == "dlpack":
             self.logger.warning("DLPack backend does not support cache saving to disk.")
         else:
-            with self._lock:  # enter critical section again to check and update disk cache
-                disk_kernel = self._load_kernel_from_disk(
-                    key,
-                    target,
-                    target_host,
-                    out_idx,
-                    execution_backend,
-                    pass_configs,
-                    func,
-                )
-                if disk_kernel is None:
+            with self._lock:
+                if is_cache_enabled():
                     self._save_kernel_to_disk(key, kernel, func)

         # Store in memory cache after compilation
......
"""The cache utils with class and database persistence - KernelCache Class"""
import os
import json
import shutil
from pathlib import Path
from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang.engine.param import KernelParam
import threading
import cloudpickle
import logging
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.version import __version__
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
PARAMS_PATH = "params.pkl"
class AutoTunerCache:
"""
Caches compiled kernels using a class and database persistence to avoid redundant compilation.
Cache files:
kernel.cu: The compiled kernel source code
wrapped_kernel.cu: The compiled wrapped kernel source code
kernel_lib.so: The compiled kernel library
params.pkl: The compiled kernel parameters
"""
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR)
def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""
Implements singleton pattern for KernelCache class.
Args:
cache_dir (str): Directory path for storing kernel cache. Defaults to TILELANG_CACHE_DIR.
Returns:
KernelCache: The singleton instance of KernelCache.
"""
if cls._instance is None:
with cls._lock:
if cls._instance is None: # Double-checked locking
instance = super().__new__(cls)
instance.cache_dir = Path(cache_dir)
os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__)
instance.logger.setLevel(logging.ERROR)
instance._memory_cache = {} # Initialize memory cache
cls._instance = instance
return cls._instance
def _generate_key(
self,
func: Callable,
out_idx: List[int],
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
args=None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
pass_configs: dict = None,
) -> str:
"""
Generates a unique hash key for caching compiled kernels.
Args:
func (Callable): The function to be compiled.
out_idx (List[int]): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
args: Arguments passed to the function.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
Returns:
str: SHA256 hash key for the kernel configuration.
"""
func_binary = cloudpickle.dumps(func.script())
key_data = {
"version": __version__,
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
"args_repr": tuple(
repr(arg) for arg in args
), # Use repr to serialize arguments, may need more robust serialization
"target": str(target),
"target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend,
"pass_configs": pass_configs,
}
key_string = json.dumps(key_data, sort_keys=True) # Sort keys to ensure consistency
return sha256(key_string.encode()).hexdigest() # Use SHA256 to generate hash key
def cached(
self,
func: PrimFunc = None,
out_idx: List[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
Args:
func: Function to be compiled or a prepared PrimFunc
out_idx: Indices specifying which outputs to return
target: Compilation target platform
target_host: Host target platform
*args: Arguments passed to func
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
if not is_cache_enabled():
return JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
key = self._generate_key(
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
args=args,
target=target,
target_host=target_host,
pass_configs=pass_configs,
)
with self._lock:
# First check in-memory cache
if key in self._memory_cache:
self.logger.warning("Found kernel in memory cache. For better performance," \
" consider using `@tilelang.jit` instead of direct kernel caching.")
return self._memory_cache[key]
# Then check disk cache
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
execution_backend, pass_configs, func)
if kernel is not None:
# Populate memory cache with disk result
self._memory_cache[key] = kernel
return kernel
# Compile kernel if cache miss; leave critical section
kernel = JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
if execution_backend == "dlpack":
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock: # enter critical section again to check and update disk cache
disk_kernel = self._load_kernel_from_disk(
key,
target,
target_host,
out_idx,
execution_backend,
pass_configs,
func,
)
if disk_kernel is None:
self._save_kernel_to_disk(key, kernel, func)
# Store in memory cache after compilation
self._memory_cache[key] = kernel
return kernel
def set_cache_dir(self, cache_dir: str):
"""
Sets the cache directory for the kernel cache.
"""
self.cache_dir = Path(cache_dir)
def get_cache_dir(self) -> Path:
"""
Gets the cache directory for the kernel cache.
"""
return self.cache_dir
def clear_cache(self):
"""
Clears the entire kernel cache, including both in-memory and disk cache.
"""
with self._lock:
self._memory_cache.clear() # Clear in-memory cache
self._clear_disk_cache() # Clear disk cache
def _get_cache_path(self, key: str) -> str:
"""
Gets the filesystem path for a cached kernel.
Args:
key (str): The hash key identifying the kernel.
Returns:
str: Absolute path to the cache directory for this kernel.
"""
return os.path.join(self.cache_dir, key)
def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None):
"""
Persists a compiled kernel to disk cache.
Args:
key (str): The hash key identifying the kernel.
kernel (JITKernel): The compiled kernel to be saved.
func (Callable, optional): The original function.
Note:
Saves the following files:
- kernel.cu: The compiled kernel source code
- wrapped_kernel.cu: The wrapped kernel source code
- kernel_lib.so: The compiled kernel library
- params.pkl: The serialized kernel parameters
"""
cache_path = self._get_cache_path(key)
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
# Save wrapped kernel source code
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "w") as f:
f.write(kernel.adapter.get_kernel_source())
except Exception as e:
self.logger.error(f"Error saving wrapped kernel source code to disk: {e}")
# Save kernel library
try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
shutil.copy(src_lib_path, kernel_lib_path)
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")
# Save kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "wb") as f:
cloudpickle.dump(kernel.params, f)
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
def _load_kernel_from_disk(
self,
key: str,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
) -> JITKernel:
"""
Loads a previously compiled kernel from disk cache.
Args:
key (str): The hash key identifying the kernel.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
out_idx (List[int], optional): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
pass_configs (dict, optional): Configuration for compiler passes.
func (Callable, optional): The original function.
Returns:
JITKernel: The loaded kernel if found, None otherwise.
"""
cache_path = self._get_cache_path(key)
if not os.path.exists(cache_path):
return None
kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
with open(wrapped_kernel_path, "r") as f:
kernel_global_source = f.read()
except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
# Load kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
with open(params_path, "rb") as f:
kernel_params = cloudpickle.load(f)
except Exception as e:
self.logger.error(f"Error loading kernel parameters from disk: {e}")
if kernel_global_source and kernel_params:
return JITKernel.from_database(
func=func,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
params=kernel_params,
target=target,
target_host=target_host,
out_idx=out_idx,
execution_backend=execution_backend,
pass_configs=pass_configs,
)
else:
return None
def _clear_disk_cache(self):
"""
Removes all cached kernels from disk.
Note:
This operation will delete the entire cache directory and recreate it empty.
Use with caution as this operation cannot be undone.
"""
try:
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir) # Delete entire cache directory
os.makedirs(self.cache_dir, exist_ok=True) # Re-create cache directory
except Exception as e:
self.logger.error(f"Error clearing disk cache: {e}")
@@ -133,6 +133,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)

+    # Global Barrier Synchronization must be applied before
+    # SplitHostDevice pass, as the global barrier
+    mod = tilelang.transform.ThreadSync("global")(mod)
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
@@ -144,7 +147,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     else:
         mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)

-    mod = tilelang.transform.ThreadSync("global")(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
     mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
......
@@ -10,13 +10,11 @@ from typing import (
     Union,
     Callable,
     Tuple,
-    TypeVar,
     overload,
     Literal,
     Dict,  # For type hinting dicts
     Optional,
 )
-from typing_extensions import ParamSpec
 from tilelang import tvm as tvm
 from tvm.tir import PrimFunc
 from tvm.target import Target
@@ -26,6 +24,7 @@ from tilelang.cache import cached
 from os import path, makedirs
 from logging import getLogger
 import functools
+from tilelang.jit.param import Kernel, _P, _RProg

 logger = getLogger(__name__)
@@ -77,71 +76,16 @@ def compile(
     )
-
-
-# --- Mocking dependencies for the example to run ---
-# In your actual code, these would be your real types.
-class Program:
-    """Placeholder for the type returned by the original decorated function."""
-
-    def __init__(self, data: str):
-        self.data = data
-
-    def __repr__(self):
-        return f"Program('{self.data}')"
-
-
-class Kernel:
-    """Placeholder for the type of the compiled kernel."""
-
-    def __init__(self, source: str, out_idx: Any):
-        self.source_code = source
-        self.out_idx = out_idx
-
-    def get_kernel_source(self) -> str:
-        return self.source_code
-
-    def __repr__(self):
-        return f"Kernel('{self.source_code[:20]}...')"
-
-
-# --- End Mocking ---
-
-# P (Parameters) captures the argument types of the decorated function.
-_P = ParamSpec("_P")
-# R_prog (Return type of Program) captures the return type of the original decorated function.
-# We assume the original function returns something compatible with 'Program'.
-_RProg = TypeVar("_RProg", bound=Program)


 class _JitImplementation:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[True]) -> None:
...
@overload out_idx: Any
def __init__(self, target: Union[str, Target]
out_idx: Any = None, target_host: Union[str, Target]
target: Union[str, Target] = "auto", execution_backend: Literal["dlpack", "ctypes", "cython"]
target_host: Union[str, Target] = None, verbose: bool
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", pass_configs: Optional[Dict[str, Any]]
verbose: bool = False, debug_root_path: Optional[str]
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[False] = False) -> None:
...
# Actual implementation of __init__
def __init__(self, def __init__(self,
out_idx: Any = None, out_idx: Any = None,
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
@@ -149,9 +93,7 @@ class _JitImplementation:
                  execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None,
-                 debug_root_path: Optional[str] = None,
-                 *,
-                 return_program: bool = False):
+                 debug_root_path: Optional[str] = None):
         """
         Initializes the JIT compiler decorator.
@@ -183,10 +125,6 @@ class _JitImplementation:
            If None, no debug information is saved (default: None).
            If a relative path is given, it's made absolute relative to the project root
            or current working directory.
-        return_program : bool, optional
-            If True, the decorated function will return a tuple containing the
-            original program's result and the compiled kernel. If False, only the
-            compiled kernel is returned (default: False).
         """
         self.out_idx = out_idx
         self.execution_backend = execution_backend
@@ -194,7 +132,6 @@ class _JitImplementation:
         self.target_host = target_host
         self.verbose = verbose
         self.pass_configs = pass_configs
-        self.return_program = return_program  # Stored from args
        # Corrected debug_root_path handling
        self.debug_root_path = debug_root_path
@@ -204,14 +141,11 @@ class _JitImplementation:
                 self.debug_root_path = path.join(base_path, self.debug_root_path)
             except NameError:
                 self.debug_root_path = path.abspath(self.debug_root_path)
-        # If debug_root_path was None initially, it remains None.
-        # Type hint the caches
-        self._program_cache: Dict[tuple, _RProg] = {}
         self._kernel_cache: Dict[tuple, Kernel] = {}
-    # Overload __call__ based on the value of self.return_program
     # This tells the type checker what the *wrapper* function will return.
+    # this is for linting, please do not remove it.
     @overload
     def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
         ...
@@ -228,17 +162,20 @@ class _JitImplementation:
         @functools.wraps(func)
         def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
+            # Separate out the tuning parameters from the user's kwargs
+            tune_params = kwargs.pop('__tune_params', {})
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
             key = (key_args_tuple, key_kwargs_tuple)
-            if key not in self._program_cache:
+            if key not in self._kernel_cache:
                 # Ensure 'func' (the original user function) is used correctly
                 program_result_source = func
                 if isinstance(program_result_source, PrimFunc):
                     program_result = program_result_source
                 elif callable(program_result_source):
-                    program_result = program_result_source(*args, **kwargs)
+                    program_result = program_result_source(*args, **kwargs, **tune_params)
                 else:
                     raise ValueError(f"Invalid function type: {type(program_result_source)}")
@@ -262,16 +199,9 @@ class _JitImplementation:
                     with open(path.join(self.debug_root_path, program_file), 'w') as f:
                         print(program_result.script(), file=f)
-                self._program_cache[key] = program_result
                 self._kernel_cache[key] = kernel_result
-            cached_program = self._program_cache[key]
-            cached_kernel = self._kernel_cache[key]
-
-            if self.return_program:
-                return cached_program, cached_kernel
-            else:
-                return cached_kernel
+            return self._kernel_cache[key]
         return wrapper
@@ -285,16 +215,12 @@ def jit(  # This is the new public interface
         execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
         verbose: bool = False,
         pass_configs: Optional[Dict[str, Any]] = None,
-        debug_root_path: Optional[str] = None,
-        return_program: bool = False):
+        debug_root_path: Optional[str] = None):
     """
     Just-In-Time (JIT) compiler decorator for TileLang functions.
-    This decorator can be used in two ways:
-    1. Without arguments (e.g., `@tilelang.jit`):
+    This decorator can be used without arguments (e.g., `@tilelang.jit`):
        Applies JIT compilation with default settings.
-    2. With arguments (e.g., `@tilelang.jit(target="cuda", return_program=True)`):
-       Configures the JIT compilation process with the specified options.
     Parameters
     ----------
@@ -314,9 +240,6 @@ def jit(  # This is the new public interface
         Configurations for TVM's pass context. Defaults to None.
     debug_root_path : Optional[str], optional
         Directory to save compiled kernel source for debugging. Defaults to None.
-    return_program : bool, optional
-        If True, the decorated function returns a tuple (original program's result, compiled kernel).
-        Otherwise, only the compiled kernel is returned. Defaults to False.
     Returns
     -------
@@ -334,8 +257,7 @@ def jit(  # This is the new public interface
             execution_backend=execution_backend,
             verbose=verbose,
             pass_configs=pass_configs,
-            debug_root_path=debug_root_path,
-            return_program=return_program)
+            debug_root_path=debug_root_path)
         return default_decorator(func)
     elif isinstance(func, PrimFunc):
         raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
@@ -350,6 +272,5 @@ def jit(  # This is the new public interface
         execution_backend=execution_backend,
         verbose=verbose,
         pass_configs=pass_configs,
-        debug_root_path=debug_root_path,
-        return_program=return_program)
+        debug_root_path=debug_root_path)
     return configured_decorator
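The caching behaviour of the rewritten wrapper can be seen in isolation with a small standalone sketch. It mirrors the hunks above (pop the private `__tune_params` kwarg, build the cache key from the remaining call arguments, compile once per key) but substitutes a plain string for the compiled kernel, so nothing here is the real `tilelang.compile` path; `jit_like` and `matmul_program` are invented for illustration.

```python
import functools
from typing import Any, Callable, Dict


def jit_like(func: Callable) -> Callable:
    """Toy re-creation of the wrapper produced by _JitImplementation.__call__."""
    kernel_cache: Dict[tuple, Any] = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Tuning parameters arrive under a private kwarg and are forwarded
        # to the program-building function, as in the hunk above.
        tune_params = kwargs.pop("__tune_params", {})
        key = (args, tuple(sorted(kwargs.items())))
        if key not in kernel_cache:
            program = func(*args, **kwargs, **tune_params)
            kernel_cache[key] = f"kernel<{program}>"  # stand-in for compilation
        return kernel_cache[key]

    return wrapper


@jit_like
def matmul_program(m: int, n: int, k: int, block_M: int = 64):
    return ("matmul", m, n, k, block_M)


k1 = matmul_program(1024, 1024, 1024)
k2 = matmul_program(1024, 1024, 1024)
assert k1 is k2  # the second call is served from the per-decorator cache

# Tuner path: the private kwarg is popped and forwarded as block_M=128.
k3 = matmul_program(512, 512, 512, __tune_params={"block_M": 128})
```

Note that, as in the diff, the tuning kwargs are popped before the key is built, so they reach the program builder but do not appear in the user-visible cache key.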
@@ -33,6 +33,11 @@ class JITKernel(object):
     adapter: BaseKernelAdapter = None
     torch_function: Callable = None
 
+    # tuner result
+    latency: float = None
+    config: Dict[str, Any] = None
+    ref_latency: float = None
+
     def __init__(
         self,
         func: PrimFunc = None,
@@ -342,6 +347,51 @@ class JITKernel(object):
     def run_once(self, func: Optional[Callable] = None) -> None:
         return self.get_profiler().run_once(func)
+
+    def update_tuner_result(self, latency: float, config: Dict[str, Any],
+                            ref_latency: float) -> "JITKernel":
+        """
+        Updates the tuning results for this kernel.
+
+        Parameters
+        ----------
+        latency : float
+            The measured latency of this kernel configuration.
+        config : Dict[str, Any]
+            The configuration parameters used for this kernel.
+        ref_latency : float
+            The reference latency to compare against.
+
+        Returns
+        -------
+        JITKernel
+            This kernel instance, enabling call chaining.
+        """
+        self.latency = latency
+        self.config = config
+        self.ref_latency = ref_latency
+        return self
+
+    def get_tuner_result(self) -> Dict[str, Any]:
+        """
+        Gets the tuning results for this kernel.
+
+        Returns
+        -------
+        Dict[str, Any]
+            A dictionary containing:
+            - latency: The measured latency of this kernel
+            - config: The configuration parameters used
+            - ref_latency: The reference latency for comparison
+        """
+        if self.latency is None:
+            raise ValueError("Tuning results are not available. Please tune the kernel first.")
+        return {
+            "latency": self.latency,
+            "config": self.config,
+            "ref_latency": self.ref_latency,
+        }
+
     @property
     def out_idx(self) -> List[int]:
         return self.adapter.result_idx
...
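The new tuner-result accessors are small enough to exercise on their own. The snippet below is a standalone mimic of just that bookkeeping (no compilation, no real `JITKernel`; the latency numbers and config are made up) to show the dictionary shape that `get_tuner_result` returns and the chaining that `update_tuner_result` enables.

```python
from typing import Any, Dict, Optional


class TunerResultDemo:
    """Standalone mimic of the tuner-result fields added to JITKernel."""

    latency: Optional[float] = None
    config: Optional[Dict[str, Any]] = None
    ref_latency: Optional[float] = None

    def update_tuner_result(self, latency, config, ref_latency):
        self.latency, self.config, self.ref_latency = latency, config, ref_latency
        return self  # returns the instance, so calls can be chained

    def get_tuner_result(self):
        if self.latency is None:
            raise ValueError("Tuning results are not available. Please tune the kernel first.")
        return {"latency": self.latency, "config": self.config, "ref_latency": self.ref_latency}


kernel = TunerResultDemo()
result = kernel.update_tuner_result(
    latency=0.12, config={"block_M": 128, "block_N": 128}, ref_latency=0.18).get_tuner_result()
print(result)
# {'latency': 0.12, 'config': {'block_M': 128, 'block_N': 128}, 'ref_latency': 0.18}
```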
+from typing import (
+    Any,
+    TypeVar,
+)
+from typing_extensions import ParamSpec
+
+
+# --- Mocking dependencies for the example to run ---
+# In your actual code, these would be your real types.
+class Program:
+    """Placeholder for the type returned by the original decorated function."""
+
+    def __init__(self, data: str):
+        self.data = data
+
+    def __repr__(self):
+        return f"Program('{self.data}')"
+
+
+class Kernel:
+    """Placeholder for the type of the compiled kernel."""
+
+    def __init__(self, source: str, out_idx: Any):
+        self.source_code = source
+        self.out_idx = out_idx
+
+    def get_kernel_source(self) -> str:
+        return self.source_code
+
+    def __repr__(self):
+        return f"Kernel('{self.source_code[:20]}...')"
+
+
+# --- End Mocking ---
+
+# P (Parameters) captures the argument types of the decorated function.
+_P = ParamSpec("_P")
+# R_prog (Return type of Program) captures the return type of the original decorated function.
+# We assume the original function returns something compatible with 'Program'.
+_RProg = TypeVar("_RProg", bound=Program)
+
+__all__ = ["Program", "Kernel", "_P", "_RProg"]
@@ -98,13 +98,21 @@ class Profiler:
         if isinstance(lib_outs, torch.Tensor):
             lib_outs = [lib_outs]
+        elif lib_outs is None:
+            lib_outs = []
         if isinstance(ref_outs, torch.Tensor):
             ref_outs = [ref_outs]
         elif ref_outs is None:
             ref_outs = []
-        assert len(lib_outs) == len(ref_outs), "len(lib_outs) not equals to len(ref_outs) !"
+
+        ref_tensors = ins + ref_outs
+        lib_tensors = ins + lib_outs
+        assert len(lib_tensors) == len(
+            ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !"
         # torch.set_printoptions(edgeitems=torch.inf)
-        for lhs, rhs in zip(lib_outs, ref_outs):
+        for lhs, rhs in zip(lib_tensors, ref_tensors):
             # close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
             # total_elements = lhs.numel()
             # num_not_close = (~close_mask).sum().item()
...
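The profiler change above folds the input tensors into the comparison, so kernels that write results in place into their inputs are validated as well. Below is a hedged sketch of that comparison logic as a standalone helper; it uses `torch.testing.assert_close` instead of the manual `isclose` bookkeeping that is commented out in the real code, and the tolerances, helper name, and tensor shapes are arbitrary.

```python
import torch


def assert_similar(ins, lib_outs, ref_outs, rtol=1e-2, atol=1e-2):
    """Compare library outputs against reference outputs, with inputs included."""
    lib_outs = [lib_outs] if isinstance(lib_outs, torch.Tensor) else (lib_outs or [])
    ref_outs = [ref_outs] if isinstance(ref_outs, torch.Tensor) else (ref_outs or [])
    lib_tensors = ins + lib_outs  # inputs first, mirroring the diff
    ref_tensors = ins + ref_outs
    assert len(lib_tensors) == len(ref_tensors), "tensor list lengths differ"
    for lhs, rhs in zip(lib_tensors, ref_tensors):
        torch.testing.assert_close(lhs, rhs, rtol=rtol, atol=atol)


a = torch.randn(4, 4)
assert_similar(ins=[a], lib_outs=a * 2, ref_outs=a + a)  # passes: identical results
```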