Commit 7171aff6 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Autotune] Introduce cache mechanism for auto tuner (#527)

* [Enhancement] Add commit ID to versioning and improve logging initialization

* Updated `get_tilelang_version` to include an optional commit ID in the version string.
* Enhanced the `TileLangBuilPydCommand` to write the version with commit ID to the VERSION file during the build process.
* Introduced a new function `get_git_commit_id` in `version.py` to retrieve the current git commit hash.
* Refactored logger initialization in `autotuner/__init__.py` to ensure handlers are set up only once, improving performance and clarity.
* Minor fixes in `flatten_buffer.cc` and `kernel_cache.py` for better handling of versioning and logging.

* [Refactor] Enhance AutoTuner and JITKernel for improved performance and caching

* Refactored the AutoTuner class to include new methods for setting compilation and profiling arguments, enhancing configurability.
* Introduced caching mechanisms for tuning results, allowing for faster retrieval of previously computed configurations.
* Updated JITKernel to store tuning results, including latency and configuration details, improving the kernel's performance tracking.
* Added new methods for generating cache keys and saving/loading results to/from disk, streamlining the tuning process.
* Enhanced the overall structure and readability of the autotuning logic, ensuring better maintainability and clarity.
* Minor adjustments in related modules to support the new caching and profiling features.

* [Refactor] Clean up code formatting and improve readability in AutoTuner and related modules

* Consolidated import statements and removed unnecessary line breaks for better readability.
* Standardized function argument formatting across the AutoTuner and CompileArgs classes.
* Enhanced consistency in the use of whitespace and indentation throughout the codebase.
* Minor adjustments in the Profiler and JITKernel classes to improve clarity and maintainability.
* Ensured that all changes adhere to the project's coding style guidelines.

* [Refactor] Remove redundant type hints in AutoTuner modules

* Simplified import statements in `__init__.py` and `param.py` by removing unnecessary duplicate type hints for `Any`.
* Improved code readability and maintainability by streamlining type imports across the AutoTuner module.

* [Refactor] Update AutoTuner configuration for improved profiling and target detection

* Enhanced the AutoTuner configuration across multiple examples by adding `set_profile_args` to better manage profiling settings.
* Standardized the use of `target="auto"` in compile arguments to ensure automatic target detection.
* Removed redundant target specifications in certain instances to streamline the configuration process.
* Improved overall clarity and maintainability of the autotuning logic in various example scripts.

* [Refactor] Simplify code formatting and improve readability in example scripts

* Consolidated function argument formatting in `benchmark_mla_decode_amd_tilelang.py`, `example_elementwise_add.py`, and `performance.py` for better clarity.
* Removed unnecessary line breaks and standardized argument placement across multiple files.
* Enhanced overall code readability and maintainability in autotuning examples and performance scripts.

* [Refactor] Update JIT decorator usage across multiple files

* Removed redundant parameters from the JIT decorator in various benchmark and example scripts, simplifying the code.
* Standardized the import of the JIT decorator from `tilelang`, enhancing consistency across the codebase.
* Improved overall readability and maintainability by consolidating import statements and cleaning up function definitions.

* [Refactor] Standardize JIT decorator formatting across benchmark and example scripts

* Simplified the formatting of the JIT decorator in multiple files by removing unnecessary line breaks.
* Enhanced code readability and consistency in the usage of the JIT decorator across benchmark and example scripts.
* Improved overall maintainability by ensuring uniformity in function definitions and decorator usage.
parent 09581e4e
......@@ -93,10 +93,10 @@ from .layout import (
)
from . import (
transform, # noqa: F401
autotuner, # noqa: F401
language, # noqa: F401
engine, # noqa: F401
)
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401
from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401
......
This diff is collapsed.
"""The auto-tune parameters.
"""
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from typing import Callable, List, Literal, Any, Optional, Union, Dict
from dataclasses import dataclass
from pathlib import Path
from tilelang.jit import JITKernel
import cloudpickle
import os
import shutil
from tilelang.engine.param import KernelParam
from tilelang import logger
import json
import hashlib
# File names written inside an AutotuneResult cache directory.
BEST_CONFIG_PATH = "best_config.json"  # best tuning configuration (JSON)
FUNCTION_PATH = "function.pkl"  # cloudpickled tuned function
LATENCY_PATH = "latency.json"  # measured and reference latency (JSON)
KERNEL_PATH = "kernel.cu"  # compiled kernel source
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"  # host-wrapped kernel source
KERNEL_LIB_PATH = "kernel_lib.so"  # compiled shared library
PARAMS_PATH = "params.pkl"  # cloudpickled kernel parameters
@dataclass(frozen=True)
class CompileArgs:
    """Compile arguments for the auto-tuner. Detailed description can be found in `tilelang.jit.compile`.

    Attributes:
        out_idx: List of output tensor indices.
        execution_backend: Execution backend to use for kernel execution (default: "cython").
        target: Compilation target, either as a string or a TVM Target object (default: "auto").
        target_host: Target host for cross-compilation (default: None).
        verbose: Whether to enable verbose output (default: False).
        pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
            Available options:
                "tir.disable_vectorize": bool, default: False
                "tl.disable_tma_lower": bool, default: False
                "tl.disable_warp_specialized": bool, default: False
                "tl.config_index_bitwidth": int, default: None
                "tl.disable_dynamic_tail_split": bool, default: False
                "tl.dynamic_vectorize_size_bits": int, default: 128
                "tl.disable_safe_memory_legalize": bool, default: False
    """

    out_idx: Union[List[int], int] = -1
    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
    target: Literal['auto', 'cuda', 'hip'] = 'auto'
    target_host: Optional[Union[str, Target]] = None
    verbose: bool = False
    pass_configs: Optional[Dict[str, Any]] = None

    def compile_program(self, program: PrimFunc):
        """Compile `program` with these arguments via `tilelang.compile`."""
        # NOTE(review): `execution_backend` is part of this dataclass but is not
        # forwarded to `tilelang.compile` here — confirm whether it should be.
        return tilelang.compile(
            program,
            out_idx=self.out_idx,
            target=self.target,
            target_host=self.target_host,
            verbose=self.verbose,
            pass_configs=self.pass_configs)

    def __hash__(self) -> int:
        # Hash a canonical JSON rendering of the fields so equal argument sets
        # map to the same cache key across processes.
        data = {
            "out_idx":
                self.out_idx,
            "execution_backend":
                self.execution_backend,
            "target":
                self.target,
            "target_host":
                str(self.target_host) if self.target_host else None,
            "verbose":
                self.verbose,
            "pass_configs":
                json.dumps(self.pass_configs, sort_keys=True) if self.pass_configs else None,
        }
        hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
        return int.from_bytes(hash_obj.digest(), byteorder='big')
@dataclass(frozen=True)
class ProfileArgs:
    """Profile arguments for the auto-tuner.

    Note: the original docstring's closing quotes were misplaced, so a stale
    copy of the field declarations lived inside the docstring as dead text.
    The field list below is the single authoritative one.

    Attributes:
        warmup: Number of warmup iterations before timing.
        rep: Number of repetitions for timing.
        timeout: Maximum time (in seconds) per configuration.
        supply_type: Type of tensor supply mechanism.
        ref_prog: Reference program for correctness validation.
        supply_prog: Supply program for input tensors.
        rtol: Relative tolerance when validating against `ref_prog`.
        atol: Absolute tolerance when validating against `ref_prog`.
        max_mismatched_ratio: Maximum fraction of mismatched elements tolerated.
        skip_check: Whether to skip correctness validation entirely.
        manual_check_prog: User-supplied callable for custom correctness checks.
        cache_input_tensors: Whether to reuse input tensors across configurations.
    """

    warmup: int = 25
    rep: int = 100
    timeout: int = 30
    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
    ref_prog: Callable = None
    supply_prog: Callable = None
    rtol: float = 1e-2
    atol: float = 1e-2
    max_mismatched_ratio: float = 0.01
    skip_check: bool = False
    manual_check_prog: Callable = None
    cache_input_tensors: bool = True

    def __hash__(self) -> int:
        # Only JSON-serializable fields participate in the hash; the callable
        # fields (ref_prog, supply_prog, manual_check_prog) are deliberately
        # excluded because they have no stable serialized form.
        data = {
            "warmup": self.warmup,
            "rep": self.rep,
            "timeout": self.timeout,
            "supply_type": str(self.supply_type),
            "rtol": self.rtol,
            "atol": self.atol,
            "max_mismatched_ratio": self.max_mismatched_ratio,
        }
        hash_obj = hashlib.sha256(json.dumps(data, sort_keys=True).encode('utf-8'))
        return int.from_bytes(hash_obj.digest(), byteorder='big')
@dataclass(frozen=True)
class AutotuneResult:
    """Results from auto-tuning process, with helpers to persist and restore them.

    Attributes:
        latency: Best achieved execution latency.
        config: Configuration that produced the best result.
        ref_latency: Reference implementation latency.
        libcode: Generated library code.
        func: Optimized function.
        kernel: Compiled kernel function.
    """

    latency: float
    config: dict
    ref_latency: float
    libcode: str
    func: Callable
    kernel: Callable

    def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel) -> None:
        """
        Persists a compiled kernel to disk cache.

        Args:
            cache_path (Path): Directory to write the kernel artifacts into.
            kernel (JITKernel): The compiled kernel to be saved.

        Note:
            Saves the following files:
            - kernel.cu: The compiled kernel source code
            - wrapped_kernel.cu: The wrapped kernel source code
            - kernel_lib.so: The compiled kernel library
            - params.pkl: The serialized kernel parameters
            Each artifact is written best-effort: a failure is logged and the
            remaining artifacts are still attempted.
        """
        os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists

        # Save kernel source code
        try:
            kernel_path = os.path.join(cache_path, KERNEL_PATH)
            with open(kernel_path, "w") as f:
                f.write(kernel.artifact.kernel_source)
        except Exception as e:
            logger.error(f"Error saving kernel source code to disk: {e}")

        # Save wrapped kernel source code
        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            with open(wrapped_kernel_path, "w") as f:
                f.write(kernel.adapter.get_kernel_source())
        except Exception as e:
            logger.error(f"Error saving wrapped kernel source code to disk: {e}")

        # Save kernel library
        try:
            kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
            src_lib_path = kernel.adapter.libpath
            shutil.copy(src_lib_path, kernel_lib_path)
        except Exception as e:
            logger.error(f"Error saving kernel library to disk: {e}")

        # Save kernel parameters
        try:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            with open(params_path, "wb") as f:
                cloudpickle.dump(kernel.params, f)
        except Exception as e:
            logger.error(f"Error saving kernel parameters to disk: {e}")

    def _load_kernel_from_disk(
        self,
        cache_path: Path,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        out_idx: List[int] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        pass_configs: dict = None,
        func: Callable = None,
    ) -> JITKernel:
        """
        Loads a previously compiled kernel from disk cache.

        Args:
            cache_path (Path): Directory holding the kernel artifacts.
            target (Union[str, Target]): Compilation target platform. Defaults to "auto".
            target_host (Union[str, Target], optional): Host target platform.
            out_idx (List[int], optional): Indices specifying which outputs to return.
            execution_backend (Literal): Backend type for execution. Defaults to "cython".
            pass_configs (dict, optional): Configuration for compiler passes.
            func (Callable, optional): The original function.

        Returns:
            JITKernel: The loaded kernel if found, None otherwise.
        """
        if not os.path.exists(cache_path):
            return None

        kernel_global_source: Optional[str] = None
        kernel_params: Optional[List[KernelParam]] = None
        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            with open(wrapped_kernel_path, "r") as f:
                kernel_global_source = f.read()
        except Exception as e:
            logger.error(f"Error loading wrapped kernel source code from disk: {e}")

        kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)

        # Load kernel parameters
        try:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            with open(params_path, "rb") as f:
                kernel_params = cloudpickle.load(f)
        except Exception as e:
            logger.error(f"Error loading kernel parameters from disk: {e}")

        # Only reconstruct the kernel when both the wrapped source and the
        # parameters were read successfully; otherwise treat it as a miss.
        if kernel_global_source and kernel_params:
            return JITKernel.from_database(
                func=func,
                kernel_global_source=kernel_global_source,
                kernel_lib_path=kernel_lib_path,
                params=kernel_params,
                target=target,
                target_host=target_host,
                out_idx=out_idx,
                execution_backend=execution_backend,
                pass_configs=pass_configs,
            )
        else:
            return None

    def save_to_disk(self, path: Path) -> None:
        """Persist this result (best config, function, latencies, kernel) under `path`."""
        if not os.path.exists(path):
            os.makedirs(path)
        # save best config
        with open(path / BEST_CONFIG_PATH, "w") as f:
            json.dump(self.config, f)
        # save function
        with open(path / FUNCTION_PATH, "wb") as f:
            cloudpickle.dump(self.func, f)
        # save ref latency
        with open(path / LATENCY_PATH, "w") as f:
            json.dump({
                "latency": self.latency,
                "ref_latency": self.ref_latency,
            }, f)
        # save kernel
        self._save_kernel_to_disk(path, self.kernel)

    @classmethod
    def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
        """
        Restore a previously saved AutotuneResult from `path`.

        Args:
            path (Path): Directory previously populated by `save_to_disk`.
            compile_args (CompileArgs): Compile arguments used to rebuild the kernel.

        Returns:
            AutotuneResult: The restored result, or None if any artifact is missing.
        """
        if not os.path.exists(path):
            return None
        # load best config
        with open(path / BEST_CONFIG_PATH, "r") as f:
            config = json.load(f)
        # load function
        with open(path / FUNCTION_PATH, "rb") as f:
            func = cloudpickle.load(f)
        # load latency
        with open(path / LATENCY_PATH, "r") as f:
            latency = json.load(f)
        latency, ref_latency = latency["latency"], latency["ref_latency"]
        # NOTE(review): `cls` is passed in place of `self` here; this only works
        # because _load_kernel_from_disk never reads instance state — consider
        # making it a @staticmethod.
        kernel = cls._load_kernel_from_disk(cls, path, compile_args.target,
                                            compile_args.target_host, compile_args.out_idx,
                                            compile_args.execution_backend,
                                            compile_args.pass_configs, func)
        if kernel is None:
            return None
        kernel.update_tuner_result(
            config=config,
            latency=latency,
            ref_latency=ref_latency,
        )
        result = cls(
            config=config,
            func=func,
            kernel=kernel,
            libcode=kernel.get_kernel_source(),
            latency=latency,
            ref_latency=ref_latency,
        )
        return result
......@@ -174,17 +174,8 @@ class KernelCache:
if execution_backend == "dlpack":
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock: # enter critical section again to check and update disk cache
disk_kernel = self._load_kernel_from_disk(
key,
target,
target_host,
out_idx,
execution_backend,
pass_configs,
func,
)
if disk_kernel is None:
with self._lock:
if is_cache_enabled():
self._save_kernel_to_disk(key, kernel, func)
# Store in memory cache after compilation
......
"""The cache utils with class and database persistence - KernelCache Class"""
import os
import json
import shutil
from pathlib import Path
from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang.engine.param import KernelParam
import threading
import cloudpickle
import logging
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.version import __version__
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
PARAMS_PATH = "params.pkl"
class AutoTunerCache:
    """
    Caches compiled kernels using a class and database persistence to avoid redundant compilation.

    Cache files:
        kernel.cu: The compiled kernel source code
        wrapped_kernel.cu: The compiled wrapped kernel source code
        kernel_lib.so: The compiled kernel library
        params.pkl: The compiled kernel parameters
    """

    _instance = None  # For implementing singleton pattern
    _lock = threading.Lock()  # For thread safety
    _memory_cache = {}  # In-memory cache dictionary
    # Root directory for on-disk kernel caches.
    cache_dir: Path = Path(TILELANG_CACHE_DIR)

    def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
        """
        Implements singleton pattern for the AutoTunerCache class.

        Args:
            cache_dir (str): Directory path for storing kernel cache. Defaults to TILELANG_CACHE_DIR.

        Returns:
            AutoTunerCache: The singleton instance of AutoTunerCache.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:  # Double-checked locking
                    instance = super().__new__(cls)
                    instance.cache_dir = Path(cache_dir)
                    os.makedirs(instance.cache_dir, exist_ok=True)
                    instance.logger = logging.getLogger(__name__)
                    instance.logger.setLevel(logging.ERROR)
                    instance._memory_cache = {}  # Initialize memory cache
                    cls._instance = instance
        return cls._instance

    def _generate_key(
        self,
        func: Callable,
        out_idx: List[int],
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        args=None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        pass_configs: dict = None,
    ) -> str:
        """
        Generates a unique hash key for caching compiled kernels.

        Args:
            func (Callable): The function to be compiled.
            out_idx (List[int]): Indices specifying which outputs to return.
            execution_backend (Literal): Backend type for execution. Defaults to "cython".
            args: Arguments passed to the function.
            target (Union[str, Target]): Compilation target platform. Defaults to "auto".
            target_host (Union[str, Target], optional): Host target platform.
            pass_configs (dict, optional): Configuration for compiler passes.

        Returns:
            str: SHA256 hash key for the kernel configuration.
        """
        # The package version is part of the key so stale caches from older
        # releases are never reused.
        func_binary = cloudpickle.dumps(func.script())
        key_data = {
            "version": __version__,
            "func": sha256(func_binary).hexdigest(),  # Use SHA256 to generate hash key
            "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
            "args_repr": tuple(
                repr(arg) for arg in args
            ),  # Use repr to serialize arguments, may need more robust serialization
            "target": str(target),
            "target_host": str(target_host) if target_host else None,
            "execution_backend": execution_backend,
            "pass_configs": pass_configs,
        }
        key_string = json.dumps(key_data, sort_keys=True)  # Sort keys to ensure consistency
        return sha256(key_string.encode()).hexdigest()  # Use SHA256 to generate hash key

    def cached(
        self,
        func: PrimFunc = None,
        out_idx: List[int] = None,
        *args,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        verbose: bool = False,
        pass_configs: dict = None,
    ) -> JITKernel:
        """
        Caches and reuses compiled kernels to avoid redundant compilation.

        Lookup order: in-memory cache, then disk cache, then fresh compilation
        (performed outside the lock so other threads are not blocked).

        Args:
            func: Function to be compiled or a prepared PrimFunc
            out_idx: Indices specifying which outputs to return
            target: Compilation target platform
            target_host: Host target platform
            *args: Arguments passed to func

        Returns:
            JITKernel: The compiled kernel, either freshly compiled or from cache
        """
        if not is_cache_enabled():
            # Caching disabled: always compile fresh.
            return JITKernel(
                func,
                out_idx=out_idx,
                execution_backend=execution_backend,
                target=target,
                target_host=target_host,
                verbose=verbose,
                pass_configs=pass_configs,
            )
        key = self._generate_key(
            func=func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            args=args,
            target=target,
            target_host=target_host,
            pass_configs=pass_configs,
        )
        with self._lock:
            # First check in-memory cache
            if key in self._memory_cache:
                self.logger.warning("Found kernel in memory cache. For better performance," \
                    " consider using `@tilelang.jit` instead of direct kernel caching.")
                return self._memory_cache[key]
            # Then check disk cache
            kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
                                                 execution_backend, pass_configs, func)
            if kernel is not None:
                # Populate memory cache with disk result
                self._memory_cache[key] = kernel
                return kernel
        # Compile kernel if cache miss; leave critical section
        kernel = JITKernel(
            func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
        )
        if execution_backend == "dlpack":
            self.logger.warning("DLPack backend does not support cache saving to disk.")
        else:
            with self._lock:  # enter critical section again to check and update disk cache
                # Re-check: another thread may have persisted the same kernel
                # while we were compiling outside the lock.
                disk_kernel = self._load_kernel_from_disk(
                    key,
                    target,
                    target_host,
                    out_idx,
                    execution_backend,
                    pass_configs,
                    func,
                )
                if disk_kernel is None:
                    self._save_kernel_to_disk(key, kernel, func)
        # Store in memory cache after compilation
        self._memory_cache[key] = kernel
        return kernel

    def set_cache_dir(self, cache_dir: str) -> None:
        """
        Sets the cache directory for the kernel cache.
        """
        self.cache_dir = Path(cache_dir)

    def get_cache_dir(self) -> Path:
        """
        Gets the cache directory for the kernel cache.
        """
        return self.cache_dir

    def clear_cache(self) -> None:
        """
        Clears the entire kernel cache, including both in-memory and disk cache.
        """
        with self._lock:
            self._memory_cache.clear()  # Clear in-memory cache
            self._clear_disk_cache()  # Clear disk cache

    def _get_cache_path(self, key: str) -> str:
        """
        Gets the filesystem path for a cached kernel.

        Args:
            key (str): The hash key identifying the kernel.

        Returns:
            str: Absolute path to the cache directory for this kernel.
        """
        return os.path.join(self.cache_dir, key)

    def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None):
        """
        Persists a compiled kernel to disk cache.

        Args:
            key (str): The hash key identifying the kernel.
            kernel (JITKernel): The compiled kernel to be saved.
            func (Callable, optional): The original function.

        Note:
            Saves the following files:
            - kernel.cu: The compiled kernel source code
            - wrapped_kernel.cu: The wrapped kernel source code
            - kernel_lib.so: The compiled kernel library
            - params.pkl: The serialized kernel parameters
            Each artifact is written best-effort: a failure is logged and the
            remaining artifacts are still attempted.
        """
        cache_path = self._get_cache_path(key)
        os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists
        # Save kernel source code
        try:
            kernel_path = os.path.join(cache_path, KERNEL_PATH)
            with open(kernel_path, "w") as f:
                f.write(kernel.artifact.kernel_source)
        except Exception as e:
            self.logger.error(f"Error saving kernel source code to disk: {e}")
        # Save wrapped kernel source code
        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            with open(wrapped_kernel_path, "w") as f:
                f.write(kernel.adapter.get_kernel_source())
        except Exception as e:
            self.logger.error(f"Error saving wrapped kernel source code to disk: {e}")
        # Save kernel library
        try:
            kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
            src_lib_path = kernel.adapter.libpath
            shutil.copy(src_lib_path, kernel_lib_path)
        except Exception as e:
            self.logger.error(f"Error saving kernel library to disk: {e}")
        # Save kernel parameters
        try:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            with open(params_path, "wb") as f:
                cloudpickle.dump(kernel.params, f)
        except Exception as e:
            self.logger.error(f"Error saving kernel parameters to disk: {e}")

    def _load_kernel_from_disk(
        self,
        key: str,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        out_idx: List[int] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        pass_configs: dict = None,
        func: Callable = None,
    ) -> JITKernel:
        """
        Loads a previously compiled kernel from disk cache.

        Args:
            key (str): The hash key identifying the kernel.
            target (Union[str, Target]): Compilation target platform. Defaults to "auto".
            target_host (Union[str, Target], optional): Host target platform.
            out_idx (List[int], optional): Indices specifying which outputs to return.
            execution_backend (Literal): Backend type for execution. Defaults to "cython".
            pass_configs (dict, optional): Configuration for compiler passes.
            func (Callable, optional): The original function.

        Returns:
            JITKernel: The loaded kernel if found, None otherwise.
        """
        cache_path = self._get_cache_path(key)
        if not os.path.exists(cache_path):
            return None
        kernel_global_source: Optional[str] = None
        kernel_params: Optional[List[KernelParam]] = None
        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            with open(wrapped_kernel_path, "r") as f:
                kernel_global_source = f.read()
        except Exception as e:
            self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
        kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
        # Load kernel parameters
        try:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            with open(params_path, "rb") as f:
                kernel_params = cloudpickle.load(f)
        except Exception as e:
            self.logger.error(f"Error loading kernel parameters from disk: {e}")
        # Only reconstruct the kernel when both the wrapped source and the
        # parameters were read successfully; otherwise treat it as a miss.
        if kernel_global_source and kernel_params:
            return JITKernel.from_database(
                func=func,
                kernel_global_source=kernel_global_source,
                kernel_lib_path=kernel_lib_path,
                params=kernel_params,
                target=target,
                target_host=target_host,
                out_idx=out_idx,
                execution_backend=execution_backend,
                pass_configs=pass_configs,
            )
        else:
            return None

    def _clear_disk_cache(self):
        """
        Removes all cached kernels from disk.

        Note:
            This operation will delete the entire cache directory and recreate it empty.
            Use with caution as this operation cannot be undone.
        """
        try:
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)  # Delete entire cache directory
            os.makedirs(self.cache_dir, exist_ok=True)  # Re-create cache directory
        except Exception as e:
            self.logger.error(f"Error clearing disk cache: {e}")
......@@ -133,6 +133,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
# Global Barrier Synchronization must be applied before
# SplitHostDevice pass, as the global barrier
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
......@@ -144,7 +147,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
else:
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.EliminateStorageSyncForMBarrier()(mod)
......
......@@ -10,13 +10,11 @@ from typing import (
Union,
Callable,
Tuple,
TypeVar,
overload,
Literal,
Dict, # For type hinting dicts
Optional,
)
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
......@@ -26,6 +24,7 @@ from tilelang.cache import cached
from os import path, makedirs
from logging import getLogger
import functools
from tilelang.jit.param import Kernel, _P, _RProg
logger = getLogger(__name__)
......@@ -77,71 +76,16 @@ def compile(
)
# --- Mocking dependencies for the example to run ---
# In your actual code, these would be your real types.
class Program:
    """Placeholder for the type returned by the original decorated function."""

    def __init__(self, data: str):
        self.data = data

    def __repr__(self):
        # Same textual form as the canonical implementation.
        return "Program('{}')".format(self.data)
class Kernel:
    """Placeholder for the type of the compiled kernel."""

    def __init__(self, source: str, out_idx: Any):
        self.source_code = source
        self.out_idx = out_idx

    def get_kernel_source(self) -> str:
        # Simple accessor mirroring the real kernel API.
        return self.source_code

    def __repr__(self):
        # Show only a 20-character prefix of the source, as the original does.
        prefix = self.source_code[:20]
        return "Kernel('{}...')".format(prefix)
# --- End Mocking ---
# P (Parameters) captures the argument types of the decorated function.
_P = ParamSpec("_P")
# R_prog (Return type of Program) captures the return type of the original decorated function.
# We assume the original function returns something compatible with 'Program'.
_RProg = TypeVar("_RProg", bound=Program)
class _JitImplementation:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[True]) -> None:
...
@overload
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: Literal[False] = False) -> None:
...
out_idx: Any
target: Union[str, Target]
target_host: Union[str, Target]
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool
pass_configs: Optional[Dict[str, Any]]
debug_root_path: Optional[str]
# Actual implementation of __init__
def __init__(self,
out_idx: Any = None,
target: Union[str, Target] = "auto",
......@@ -149,9 +93,7 @@ class _JitImplementation:
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
*,
return_program: bool = False):
debug_root_path: Optional[str] = None):
"""
Initializes the JIT compiler decorator.
......@@ -183,10 +125,6 @@ class _JitImplementation:
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
return_program : bool, optional
If True, the decorated function will return a tuple containing the
original program's result and the compiled kernel. If False, only the
compiled kernel is returned (default: False).
"""
self.out_idx = out_idx
self.execution_backend = execution_backend
......@@ -194,7 +132,6 @@ class _JitImplementation:
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.return_program = return_program # Stored from args
# Corrected debug_root_path handling
self.debug_root_path = debug_root_path
......@@ -204,14 +141,11 @@ class _JitImplementation:
self.debug_root_path = path.join(base_path, self.debug_root_path)
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
# If debug_root_path was None initially, it remains None.
# Type hint the caches
self._program_cache: Dict[tuple, _RProg] = {}
self._kernel_cache: Dict[tuple, Kernel] = {}
# Overload __call__ based on the value of self.return_program
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
...
......@@ -228,17 +162,20 @@ class _JitImplementation:
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
# Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {})
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
if key not in self._program_cache:
if key not in self._kernel_cache:
# Ensure 'func' (the original user function) is used correctly
program_result_source = func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
program_result = program_result_source(*args, **kwargs, **tune_params)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
......@@ -262,16 +199,9 @@ class _JitImplementation:
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(program_result.script(), file=f)
self._program_cache[key] = program_result
self._kernel_cache[key] = kernel_result
cached_program = self._program_cache[key]
cached_kernel = self._kernel_cache[key]
if self.return_program:
return cached_program, cached_kernel
else:
return cached_kernel
return self._kernel_cache[key]
return wrapper
......@@ -285,16 +215,12 @@ def jit( # This is the new public interface
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
debug_root_path: Optional[str] = None,
return_program: bool = False):
debug_root_path: Optional[str] = None):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
This decorator can be used in two ways:
1. Without arguments (e.g., `@tilelang.jit`):
This decorator can be used without arguments (e.g., `@tilelang.jit`):
Applies JIT compilation with default settings.
2. With arguments (e.g., `@tilelang.jit(target="cuda", return_program=True)`):
Configures the JIT compilation process with the specified options.
Parameters
----------
......@@ -314,9 +240,6 @@ def jit( # This is the new public interface
Configurations for TVM's pass context. Defaults to None.
debug_root_path : Optional[str], optional
Directory to save compiled kernel source for debugging. Defaults to None.
return_program : bool, optional
If True, the decorated function returns a tuple (original program's result, compiled kernel).
Otherwise, only the compiled kernel is returned. Defaults to False.
Returns
-------
......@@ -334,8 +257,7 @@ def jit( # This is the new public interface
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
return_program=return_program)
debug_root_path=debug_root_path)
return default_decorator(func)
elif isinstance(func, PrimFunc):
raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")
......@@ -350,6 +272,5 @@ def jit( # This is the new public interface
execution_backend=execution_backend,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
return_program=return_program)
debug_root_path=debug_root_path)
return configured_decorator
......@@ -33,6 +33,11 @@ class JITKernel(object):
adapter: BaseKernelAdapter = None
torch_function: Callable = None
# tuner result
latency: float = None
config: Dict[str, Any] = None
ref_latency: float = None
def __init__(
self,
func: PrimFunc = None,
......@@ -342,6 +347,51 @@ class JITKernel(object):
def run_once(self, func: Optional[Callable] = None) -> None:
    """Execute the kernel a single time through its profiler.

    Parameters
    ----------
    func : Optional[Callable], optional
        Callable forwarded to the profiler's ``run_once``; when None the
        profiler runs the kernel itself.
    """
    profiler = self.get_profiler()
    return profiler.run_once(func)
def update_tuner_result(self, latency: float, config: Dict[str, Any],
ref_latency: float) -> "JITKernel":
"""
Updates the tuning results for this kernel.
Parameters
----------
latency : float
The measured latency of this kernel configuration.
config : Dict[str, Any]
The configuration parameters used for this kernel.
ref_latency : float
The reference latency to compare against.
Returns
-------
None
"""
self.latency = latency
self.config = config
self.ref_latency = ref_latency
return self
def get_tuner_result(self) -> Dict[str, Any]:
"""
Gets the tuning results for this kernel.
Returns
-------
Dict[str, Any]
A dictionary containing:
- latency: The measured latency of this kernel
- config: The configuration parameters used
- ref_latency: The reference latency for comparison
"""
if self.latency is None:
raise ValueError("Tuning results are not available. Please tune the kernel first.")
return {
"latency": self.latency,
"config": self.config,
"ref_latency": self.ref_latency,
}
@property
def out_idx(self) -> List[int]:
return self.adapter.result_idx
......
from typing import (
Any,
TypeVar,
)
from typing_extensions import ParamSpec
# --- Mocking dependencies for the example to run ---
# In your actual code, these would be your real types.
class Program:
    """Placeholder for the type returned by the original decorated function."""

    def __init__(self, data: str):
        # Store the raw program payload for later inspection.
        self.data = data

    def __repr__(self):
        return "Program('{}')".format(self.data)
class Kernel:
    """Placeholder for the type of the compiled kernel."""

    def __init__(self, source: str, out_idx: Any):
        # Keep the generated source and the output-index metadata.
        self.source_code = source
        self.out_idx = out_idx

    def get_kernel_source(self) -> str:
        """Return the stored kernel source text."""
        return self.source_code

    def __repr__(self):
        # Only the first 20 characters of source are shown in the repr.
        preview = self.source_code[:20]
        return f"Kernel('{preview}...')"
# --- End Mocking ---
# P (Parameters) captures the argument types of the decorated function.
_P = ParamSpec("_P")
# R_prog (Return type of Program) captures the return type of the original decorated function.
# We assume the original function returns something compatible with 'Program'.
# The bound keeps the TypeVar constrained to Program subclasses.
_RProg = TypeVar("_RProg", bound=Program)
# Explicit public API of this typing-helper module.
__all__ = ["Program", "Kernel", "_P", "_RProg"]
......@@ -98,13 +98,21 @@ class Profiler:
if isinstance(lib_outs, torch.Tensor):
lib_outs = [lib_outs]
elif lib_outs is None:
lib_outs = []
if isinstance(ref_outs, torch.Tensor):
ref_outs = [ref_outs]
elif ref_outs is None:
ref_outs = []
assert len(lib_outs) == len(ref_outs), "len(lib_outs) not equals to len(ref_outs) !"
ref_tensors = ins + ref_outs
lib_tensors = ins + lib_outs
assert len(lib_tensors) == len(
ref_tensors), "len(lib_tensors) not equals to len(ref_tensors) !"
# torch.set_printoptions(edgeitems=torch.inf)
for lhs, rhs in zip(lib_outs, ref_outs):
for lhs, rhs in zip(lib_tensors, ref_tensors):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
# total_elements = lhs.numel()
# num_not_close = (~close_mask).sum().item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment