"configs/vscode:/vscode.git/clone" did not exist on "e9b7b8ab02b8f2cf5cbe7a51d77363d7fc6d49cb"
Unverified Commit 721baedb authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Fix autotune cache (#1315)

parent 470eb74c
...@@ -13,18 +13,25 @@ from pathlib import Path ...@@ -13,18 +13,25 @@ from pathlib import Path
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
import cloudpickle import cloudpickle
import os import os
import shutil
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
from tilelang import logger from tilelang import logger
import json import json
import hashlib import hashlib
import uuid
from tilelang import env
from tvm.runtime import Executable
BEST_CONFIG_PATH = "best_config.json" BEST_CONFIG_PATH = "best_config.json"
FUNCTION_PATH = "function.pkl" FUNCTION_PATH = "function.pkl"
LATENCY_PATH = "latency.json" LATENCY_PATH = "latency.json"
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu" # Align file names with cache/kernel_cache.py
DEVICE_KERNEL_PATH = "device_kernel.cu"
HOST_KERNEL_PATH = "host_kernel.cu"
EXECUTABLE_PATH = "executable.so"
KERNEL_LIB_PATH = "kernel_lib.so" KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl" PARAMS_PATH = "params.pkl"
...@@ -143,6 +150,31 @@ class AutotuneResult: ...@@ -143,6 +150,31 @@ class AutotuneResult:
func: Callable | None = None func: Callable | None = None
kernel: Callable | None = None kernel: Callable | None = None
@staticmethod
def _load_binary(path: str):
    """Return the full contents of the file at ``path`` as raw bytes."""
    # Binary mode: the callers copy compiled artifacts, so no text decoding.
    with open(path, "rb") as stream:
        return stream.read()
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]):
    """Write a file via a temporary file that is then moved onto ``path``.

    The content is first written to a uniquely named file inside
    ``env.TILELANG_TMP_DIR`` and then moved over ``path`` with
    ``os.replace``, so readers of ``path`` never observe a partial write.

    Args:
        path: Final destination of the file.
        mode: Mode passed to ``open`` for the temporary file
            (for example ``"w"`` or ``"wb"``).
        operation: Callback that receives the open temporary file object
            and writes the content into it.

    Raises:
        Whatever ``operation``, ``open`` or ``os.replace`` raises. The
        temporary file is removed before the exception propagates.
    """
    # Create a randomly named temporary file. NOTE(review): the tmp dir is
    # assumed to live on the same filesystem as the cache directory;
    # os.replace cannot move a file across filesystems - confirm.
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}")
    try:
        with open(temp_path, mode) as temp_file:
            operation(temp_file)
        # Use atomic POSIX replace, so other processes cannot see a partial write
        os.replace(temp_path, path)
    except BaseException:
        # Do not leave a stray temporary file behind when the write or the
        # replace fails; the original error is re-raised unchanged.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
@staticmethod
def _safe_write_executable(executable: Executable, path: str):
    """Export ``executable`` to a temporary library, then move it onto ``path``.

    The library is first exported with ``executable.export_library`` to a
    uniquely named ``.so`` file inside ``env.TILELANG_TMP_DIR`` and then
    moved over ``path`` with ``os.replace``, so readers of ``path`` never
    observe a partially written library.

    Args:
        executable: Object whose ``export_library`` method produces the
            library file.
        path: Final destination of the exported library.

    Raises:
        Whatever ``export_library`` or ``os.replace`` raises. Any temporary
        file that was created is removed before the exception propagates.
    """
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): the ".so" suffix is kept on the temporary name,
    # presumably because export_library inspects the extension - confirm.
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so")
    try:
        executable.export_library(temp_path)
        os.replace(temp_path, path)
    except BaseException:
        # Clean up a partially exported library so failed saves do not
        # accumulate in the temporary directory; re-raise the original error.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False): def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
""" """
Persists a compiled kernel to disk cache. Persists a compiled kernel to disk cache.
...@@ -161,34 +193,68 @@ class AutotuneResult: ...@@ -161,34 +193,68 @@ class AutotuneResult:
""" """
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save kernel source code # Save device kernel source code
try: try:
kernel_path = os.path.join(cache_path, KERNEL_PATH) device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
if verbose: if verbose:
logger.debug(f"Saving kernel source code to file: {kernel_path}") logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None: if kernel.kernel_source is not None:
with open(kernel_path, "w") as f: self._safe_write_file(device_kernel_path, "w",
f.write(kernel.kernel_source) lambda f: f.write(kernel.kernel_source))
except Exception as e: except Exception as e:
logger.error(f"Error saving kernel source code to disk: {e}") logger.error(f"Error saving kernel source code to disk: {e}")
# Save wrapped kernel source code # Save host kernel source code (wrapped)
try: try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH) host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
if verbose: if verbose:
logger.debug(f"Saving wrapped kernel source code to file: {wrapped_kernel_path}") logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
with open(wrapped_kernel_path, "w") as f: # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
f.write(kernel.get_kernel_source()) if kernel.execution_backend == "tvm_ffi":
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_host_source()))
else:
self._safe_write_file(host_kernel_path, "w",
lambda f: f.write(kernel.adapter.get_kernel_source()))
except Exception as e: except Exception as e:
logger.error(f"Error saving wrapped kernel source code to disk: {e}") logger.error(f"Error saving wrapped kernel source code to disk: {e}")
# Save kernel library # Save kernel library (backend-specific)
try: try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH) if kernel.execution_backend == "nvrtc":
kernel_lib_file = KERNEL_CUBIN_PATH
elif kernel.execution_backend == "tvm_ffi":
kernel_lib_file = EXECUTABLE_PATH
else:
kernel_lib_file = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
if kernel.execution_backend == "nvrtc":
# Save cubin and python helper file
src_lib_path = kernel.adapter.libpath src_lib_path = kernel.adapter.libpath
kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
py_src_path = src_lib_path.replace(".cubin", ".py")
if verbose:
logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
self._safe_write_file(kernel_py_path, "wb",
lambda f: f.write(self._load_binary(py_src_path)))
if verbose: if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}") logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
shutil.copy(src_lib_path, kernel_lib_path) self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
elif kernel.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable
if verbose:
logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
self._safe_write_executable(executable, kernel_lib_path)
else:
src_lib_path = kernel.adapter.libpath
if verbose:
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
self._safe_write_file(kernel_lib_path, "wb",
lambda f: f.write(self._load_binary(src_lib_path)))
except Exception as e: except Exception as e:
logger.error(f"Error saving kernel library to disk: {e}") logger.error(f"Error saving kernel library to disk: {e}")
...@@ -197,8 +263,7 @@ class AutotuneResult: ...@@ -197,8 +263,7 @@ class AutotuneResult:
params_path = os.path.join(cache_path, PARAMS_PATH) params_path = os.path.join(cache_path, PARAMS_PATH)
if verbose: if verbose:
logger.debug(f"Saving kernel parameters to disk: {params_path}") logger.debug(f"Saving kernel parameters to disk: {params_path}")
with open(params_path, "wb") as f: self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f))
cloudpickle.dump(kernel.params, f)
except Exception as e: except Exception as e:
logger.error(f"Error saving kernel parameters to disk: {e}") logger.error(f"Error saving kernel parameters to disk: {e}")
...@@ -210,6 +275,7 @@ class AutotuneResult: ...@@ -210,6 +275,7 @@ class AutotuneResult:
out_idx: list[int] | int | None = None, out_idx: list[int] | int | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
func: Callable = None, func: Callable = None,
verbose: bool = False, verbose: bool = False,
) -> JITKernel: ) -> JITKernel:
...@@ -233,23 +299,46 @@ class AutotuneResult: ...@@ -233,23 +299,46 @@ class AutotuneResult:
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
return None return None
kernel_global_source: str | None = None # Resolve backend to pick correct file names
if execution_backend == "nvrtc":
kernel_lib_file = KERNEL_CUBIN_PATH
elif execution_backend == "tvm_ffi":
kernel_lib_file = EXECUTABLE_PATH
else:
kernel_lib_file = KERNEL_LIB_PATH
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
params_path = os.path.join(cache_path, PARAMS_PATH)
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
return None
device_kernel_source: str | None = None
host_kernel_source: str | None = None
kernel_params: list[KernelParam] | None = None kernel_params: list[KernelParam] | None = None
# Load optional device kernel source
try: try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose: if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}") logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
with open(wrapped_kernel_path) as f: with open(device_kernel_path) as f:
kernel_global_source = f.read() device_kernel_source = f.read()
except Exception as e: except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}") logger.error(f"Error loading kernel source code from disk: {e}")
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH) # Load optional host kernel source
try:
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f:
host_kernel_source = f.read()
except Exception as e:
logger.error(f"Error loading host kernel source code from disk: {e}")
# Load kernel parameters # Load kernel parameters
try: try:
params_path = os.path.join(cache_path, PARAMS_PATH)
if verbose: if verbose:
logger.debug(f"Loading kernel parameters from file: {params_path}") logger.debug(f"Loading kernel parameters from file: {params_path}")
with open(params_path, "rb") as f: with open(params_path, "rb") as f:
...@@ -257,10 +346,11 @@ class AutotuneResult: ...@@ -257,10 +346,11 @@ class AutotuneResult:
except Exception as e: except Exception as e:
logger.error(f"Error loading kernel parameters from disk: {e}") logger.error(f"Error loading kernel parameters from disk: {e}")
if kernel_global_source and kernel_params: if host_kernel_source and device_kernel_source and kernel_params:
return JITKernel.from_database( return JITKernel.from_database(
func=func, func=func,
kernel_global_source=kernel_global_source, host_kernel_source=host_kernel_source,
device_kernel_source=device_kernel_source,
kernel_lib_path=kernel_lib_path, kernel_lib_path=kernel_lib_path,
params=kernel_params, params=kernel_params,
target=target, target=target,
...@@ -268,6 +358,7 @@ class AutotuneResult: ...@@ -268,6 +358,7 @@ class AutotuneResult:
out_idx=out_idx, out_idx=out_idx,
execution_backend=execution_backend, execution_backend=execution_backend,
pass_configs=pass_configs, pass_configs=pass_configs,
compile_flags=compile_flags,
) )
else: else:
return None return None
...@@ -276,26 +367,29 @@ class AutotuneResult: ...@@ -276,26 +367,29 @@ class AutotuneResult:
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
# save best config # save best config (atomic)
if verbose: if verbose:
logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}") logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
with open(path / BEST_CONFIG_PATH, "w") as f: self._safe_write_file(
json.dump(self.config, f) str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
# save function # save function (atomic)
if verbose: if verbose:
logger.debug(f"Saving function to file: {path / FUNCTION_PATH}") logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
with open(path / FUNCTION_PATH, "wb") as f: self._safe_write_file(
cloudpickle.dump(self.func, f) str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
# save ref latency # save ref latency (atomic)
if verbose: if verbose:
logger.debug(f"Saving latency to file: {path / LATENCY_PATH}") logger.debug(f"Saving latency to file: {path / LATENCY_PATH}")
with open(path / LATENCY_PATH, "w") as f: self._safe_write_file(
json.dump({ str(path / LATENCY_PATH),
"w",
lambda f: json.dump({
"latency": self.latency, "latency": self.latency,
"ref_latency": self.ref_latency, "ref_latency": self.ref_latency,
}, f) }, f),
)
# save kernel # save kernel
self._save_kernel_to_disk(path, self.kernel) self._save_kernel_to_disk(path, self.kernel)
...@@ -306,6 +400,13 @@ class AutotuneResult: ...@@ -306,6 +400,13 @@ class AutotuneResult:
return None return None
verbose = compile_args.verbose verbose = compile_args.verbose
# Normalize target and resolve execution backend for loading
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend
norm_target = Target(_determine_target(compile_args.target)) if isinstance(
compile_args.target, str) else compile_args.target
requested_backend = compile_args.execution_backend
resolved_backend = resolve_execution_backend(requested_backend, norm_target)
# load best config # load best config
if verbose: if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}") logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
...@@ -325,10 +426,17 @@ class AutotuneResult: ...@@ -325,10 +426,17 @@ class AutotuneResult:
latency = json.load(f) latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"] latency, ref_latency = latency["latency"], latency["ref_latency"]
kernel = cls._load_kernel_from_disk(cls, path, compile_args.target, kernel = cls._load_kernel_from_disk(
compile_args.target_host, compile_args.out_idx, cls,
compile_args.execution_backend, path,
compile_args.pass_configs, func) norm_target,
compile_args.target_host,
compile_args.out_idx,
resolved_backend,
compile_args.pass_configs,
None, # compile_flags not tracked here
func,
)
if kernel is None: if kernel is None:
return None return None
kernel.update_tuner_result( kernel.update_tuner_result(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment