Commit 5802c01b authored by Lei Wang, committed by LeiWang1999
Browse files

[Cache] Implement in-memory cache (#308)

* [Enhancement] Add support for CUDA architecture 8.9 in GEMM template

- Introduced conditional inclusion of "gemm_sm89.h" for CUDA architectures 8.9 and above, enhancing compatibility with newer hardware.
- This change ensures that the GEMM template can leverage optimizations specific to the 8.9 architecture, improving performance for users with compatible GPUs.

* lint fix

* [Refactor] Clean up includes in gemm_sm89.h

- Removed duplicate inclusion of "common.h" and added "cuda_fp8.h" for improved clarity and organization.
- This change enhances the maintainability of the code by ensuring that header files are included only once and in a logical order.

* [Enhancement] Improve KernelCache with in-memory caching and detailed docstrings

- Added an in-memory cache to the KernelCache class to enhance performance by reducing disk access.
- Updated the __new__ method to initialize the memory cache and added logic to check the cache before loading from disk.
- Enhanced docstrings across multiple methods to provide clearer explanations of parameters and return values, improving code readability and maintainability.
- Implemented a clear_cache method to clear both in-memory and disk caches, ensuring efficient cache management.

* lint fix
parent a2a32dea
...@@ -33,18 +33,28 @@ class KernelCache: ...@@ -33,18 +33,28 @@ class KernelCache:
_instance = None # For implementing singleton pattern _instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
def __new__(cls, cache_dir=TILELANG_CACHE_DIR): def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""Singleton pattern to ensure only one KernelCache instance""" """
Implements singleton pattern for KernelCache class.
Args:
cache_dir (str): Directory path for storing kernel cache. Defaults to TILELANG_CACHE_DIR.
Returns:
KernelCache: The singleton instance of KernelCache.
"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: # 双重检查锁定 if cls._instance is None: # Double-checked locking
instance = super().__new__(cls) instance = super().__new__(cls)
instance.cache_dir = cache_dir instance.cache_dir = cache_dir
os.makedirs(instance.cache_dir, exist_ok=True) os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__) instance.logger = logging.getLogger(__name__)
instance.logger.setLevel(logging.ERROR) instance.logger.setLevel(logging.ERROR)
instance._memory_cache = {} # Initialize memory cache
cls._instance = instance cls._instance = instance
return cls._instance return cls._instance
...@@ -57,7 +67,20 @@ class KernelCache: ...@@ -57,7 +67,20 @@ class KernelCache:
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
) -> str: ) -> str:
"""
Generates a unique hash key for caching compiled kernels.
Args:
func (Callable): The function to be compiled.
out_idx (List[int]): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
args: Arguments passed to the function.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
Returns:
str: SHA256 hash key for the kernel configuration.
"""
func_binary = cloudpickle.dumps(func.script()) func_binary = cloudpickle.dumps(func.script())
key_data = { key_data = {
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key "func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
...@@ -114,11 +137,17 @@ class KernelCache: ...@@ -114,11 +137,17 @@ class KernelCache:
args=args, args=args,
target=target, target=target,
target_host=target_host) target_host=target_host)
with self._lock: # TODO: use filelock with self._lock:
# Attempt to load from disk # First check in-memory cache
if key in self._memory_cache:
return self._memory_cache[key]
# Then check disk cache
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx, kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
execution_backend, pass_configs, func) execution_backend, pass_configs, func)
if kernel is not None: if kernel is not None:
# Populate memory cache with disk result
self._memory_cache[key] = kernel
return kernel return kernel
# Compile kernel if cache miss; leave critical section # Compile kernel if cache miss; leave critical section
...@@ -146,6 +175,9 @@ class KernelCache: ...@@ -146,6 +175,9 @@ class KernelCache:
) )
if disk_kernel is None: if disk_kernel is None:
self._save_kernel_to_disk(key, kernel, func) self._save_kernel_to_disk(key, kernel, func)
# Store in memory cache after compilation
self._memory_cache[key] = kernel
return kernel return kernel
def clear_cache(self): def clear_cache(self):
...@@ -153,17 +185,36 @@ class KernelCache: ...@@ -153,17 +185,36 @@ class KernelCache:
Clears the entire kernel cache, including both in-memory and disk cache. Clears the entire kernel cache, including both in-memory and disk cache.
""" """
with self._lock: with self._lock:
self._memory_cache.clear() # Clear in-memory cache
self._clear_disk_cache() # Clear disk cache self._clear_disk_cache() # Clear disk cache
def _get_cache_path(self, key: str) -> str: def _get_cache_path(self, key: str) -> str:
""" """
Gets the cache file path for a given key. Gets the filesystem path for a cached kernel.
Args:
key (str): The hash key identifying the kernel.
Returns:
str: Absolute path to the cache directory for this kernel.
""" """
return os.path.join(self.cache_dir, key) return os.path.join(self.cache_dir, key)
def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None): def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None):
""" """
Saves the compiled kernel to disk. Persists a compiled kernel to disk cache.
Args:
key (str): The hash key identifying the kernel.
kernel (JITKernel): The compiled kernel to be saved.
func (Callable, optional): The original function.
Note:
Saves the following files:
- kernel.cu: The compiled kernel source code
- wrapped_kernel.cu: The wrapped kernel source code
- kernel_lib.so: The compiled kernel library
- params.pkl: The serialized kernel parameters
""" """
cache_path = self._get_cache_path(key) cache_path = self._get_cache_path(key)
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
...@@ -211,7 +262,19 @@ class KernelCache: ...@@ -211,7 +262,19 @@ class KernelCache:
func: Callable = None, func: Callable = None,
) -> JITKernel: ) -> JITKernel:
""" """
Loads kernel from disk. Loads a previously compiled kernel from disk cache.
Args:
key (str): The hash key identifying the kernel.
target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform.
out_idx (List[int], optional): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython".
pass_configs (dict, optional): Configuration for compiler passes.
func (Callable, optional): The original function.
Returns:
JITKernel: The loaded kernel if found, None otherwise.
""" """
cache_path = self._get_cache_path(key) cache_path = self._get_cache_path(key)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
...@@ -254,7 +317,11 @@ class KernelCache: ...@@ -254,7 +317,11 @@ class KernelCache:
def _clear_disk_cache(self): def _clear_disk_cache(self):
""" """
Clears the cache directory on disk. Removes all cached kernels from disk.
Note:
This operation will delete the entire cache directory and recreate it empty.
Use with caution as this operation cannot be undone.
""" """
try: try:
if os.path.exists(self.cache_dir): if os.path.exists(self.cache_dir):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment