Unverified Commit 721baedb authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Fix autotune cache (#1315)

parent 470eb74c
......@@ -13,18 +13,25 @@ from pathlib import Path
from tilelang.jit import JITKernel
import cloudpickle
import os
import shutil
from tilelang.engine.param import KernelParam
from tilelang import logger
import json
import hashlib
import uuid
from tilelang import env
from tvm.runtime import Executable
BEST_CONFIG_PATH = "best_config.json"
FUNCTION_PATH = "function.pkl"
LATENCY_PATH = "latency.json"
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
# Align file names with cache/kernel_cache.py
DEVICE_KERNEL_PATH = "device_kernel.cu"
HOST_KERNEL_PATH = "host_kernel.cu"
EXECUTABLE_PATH = "executable.so"
KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl"
......@@ -143,6 +150,31 @@ class AutotuneResult:
func: Callable | None = None
kernel: Callable | None = None
@staticmethod
def _load_binary(path: str):
with open(path, "rb") as file:
binary = file.read()
return binary
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]):
# Random a temporary file within the same FS as the cache directory
tmp_dir = env.TILELANG_TMP_DIR
os.makedirs(tmp_dir, exist_ok=True)
temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}")
with open(temp_path, mode) as temp_file:
operation(temp_file)
# Use atomic POSIX replace, so other processes cannot see a partial write
os.replace(temp_path, path)
@staticmethod
def _safe_write_executable(executable: Executable, path: str):
tmp_dir = env.TILELANG_TMP_DIR
os.makedirs(tmp_dir, exist_ok=True)
temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so")
executable.export_library(temp_path)
os.replace(temp_path, path)
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
"""
Persists a compiled kernel to disk cache.
......@@ -161,34 +193,68 @@ class AutotuneResult:
"""
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save kernel source code
# Save device kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
if verbose:
logger.debug(f"Saving kernel source code to file: {kernel_path}")
logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None:
with open(kernel_path, "w") as f:
f.write(kernel.kernel_source)
self._safe_write_file(device_kernel_path, "w",
lambda f: f.write(kernel.kernel_source))
except Exception as e:
logger.error(f"Error saving kernel source code to disk: {e}")
# Save wrapped kernel source code
# Save host kernel source code (wrapped)
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
if verbose:
logger.debug(f"Saving wrapped kernel source code to file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "w") as f:
f.write(kernel.get_kernel_source())
logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
if kernel.execution_backend == "tvm_ffi":
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_host_source()))
else:
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_kernel_source()))
except Exception as e:
logger.error(f"Error saving wrapped kernel source code to disk: {e}")
# Save kernel library
# Save kernel library (backend-specific)
try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
shutil.copy(src_lib_path, kernel_lib_path)
if kernel.execution_backend == "nvrtc":
kernel_lib_file = KERNEL_CUBIN_PATH
elif kernel.execution_backend == "tvm_ffi":
kernel_lib_file = EXECUTABLE_PATH
else:
kernel_lib_file = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
if kernel.execution_backend == "nvrtc":
# Save cubin and python helper file
src_lib_path = kernel.adapter.libpath
kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
py_src_path = src_lib_path.replace(".cubin", ".py")
if verbose:
logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
self._safe_write_file(kernel_py_path, "wb",
lambda f: f.write(self._load_binary(py_src_path)))
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
elif kernel.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
self._safe_write_executable(executable, kernel_lib_path)
else:
src_lib_path = kernel.adapter.libpath
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
except Exception as e:
logger.error(f"Error saving kernel library to disk: {e}")
......@@ -197,8 +263,7 @@ class AutotuneResult:
params_path = os.path.join(cache_path, PARAMS_PATH)
if verbose:
logger.debug(f"Saving kernel parameters to disk: {params_path}")
with open(params_path, "wb") as f:
cloudpickle.dump(kernel.params, f)
self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f))
except Exception as e:
logger.error(f"Error saving kernel parameters to disk: {e}")
......@@ -210,6 +275,7 @@ class AutotuneResult:
out_idx: list[int] | int | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
func: Callable = None,
verbose: bool = False,
) -> JITKernel:
......@@ -233,23 +299,46 @@ class AutotuneResult:
if not os.path.exists(cache_path):
return None
kernel_global_source: str | None = None
# Resolve backend to pick correct file names
if execution_backend == "nvrtc":
kernel_lib_file = KERNEL_CUBIN_PATH
elif execution_backend == "tvm_ffi":
kernel_lib_file = EXECUTABLE_PATH
else:
kernel_lib_file = KERNEL_LIB_PATH
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
params_path = os.path.join(cache_path, PARAMS_PATH)
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
return None
device_kernel_source: str | None = None
host_kernel_source: str | None = None
kernel_params: list[KernelParam] | None = None
# Load optional device kernel source
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
with open(device_kernel_path) as f:
device_kernel_source = f.read()
except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
logger.error(f"Error loading kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
# Load optional host kernel source
try:
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f:
host_kernel_source = f.read()
except Exception as e:
logger.error(f"Error loading host kernel source code from disk: {e}")
# Load kernel parameters
try:
params_path = os.path.join(cache_path, PARAMS_PATH)
if verbose:
logger.debug(f"Loading kernel parameters from file: {params_path}")
with open(params_path, "rb") as f:
......@@ -257,10 +346,11 @@ class AutotuneResult:
except Exception as e:
logger.error(f"Error loading kernel parameters from disk: {e}")
if kernel_global_source and kernel_params:
if host_kernel_source and device_kernel_source and kernel_params:
return JITKernel.from_database(
func=func,
kernel_global_source=kernel_global_source,
host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path,
params=kernel_params,
target=target,
......@@ -268,6 +358,7 @@ class AutotuneResult:
out_idx=out_idx,
execution_backend=execution_backend,
pass_configs=pass_configs,
compile_flags=compile_flags,
)
else:
return None
......@@ -276,26 +367,29 @@ class AutotuneResult:
if not os.path.exists(path):
os.makedirs(path)
# save best config
# save best config (atomic)
if verbose:
logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
with open(path / BEST_CONFIG_PATH, "w") as f:
json.dump(self.config, f)
self._safe_write_file(
str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
# save function
# save function (atomic)
if verbose:
logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
with open(path / FUNCTION_PATH, "wb") as f:
cloudpickle.dump(self.func, f)
self._safe_write_file(
str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
# save ref latency
# save ref latency (atomic)
if verbose:
logger.debug(f"Saving latency to file: {path / LATENCY_PATH}")
with open(path / LATENCY_PATH, "w") as f:
json.dump({
self._safe_write_file(
str(path / LATENCY_PATH),
"w",
lambda f: json.dump({
"latency": self.latency,
"ref_latency": self.ref_latency,
}, f)
}, f),
)
# save kernel
self._save_kernel_to_disk(path, self.kernel)
......@@ -306,6 +400,13 @@ class AutotuneResult:
return None
verbose = compile_args.verbose
# Normalize target and resolve execution backend for loading
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend
norm_target = Target(_determine_target(compile_args.target)) if isinstance(
compile_args.target, str) else compile_args.target
requested_backend = compile_args.execution_backend
resolved_backend = resolve_execution_backend(requested_backend, norm_target)
# load best config
if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
......@@ -325,10 +426,17 @@ class AutotuneResult:
latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"]
kernel = cls._load_kernel_from_disk(cls, path, compile_args.target,
compile_args.target_host, compile_args.out_idx,
compile_args.execution_backend,
compile_args.pass_configs, func)
kernel = cls._load_kernel_from_disk(
cls,
path,
norm_target,
compile_args.target_host,
compile_args.out_idx,
resolved_backend,
compile_args.pass_configs,
None, # compile_flags not tracked here
func,
)
if kernel is None:
return None
kernel.update_tuner_result(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment