Commit 0aaef97d authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Add cache directory management functions in tilelang.cache (#453)

* [Feature] Add cache directory management functions in tilelang.cache

* Introduced `get_cache_dir` and `set_cache_dir` functions to manage the kernel cache directory.
* Updated `KernelCache` class to store cache directory as a `Path` object for improved path handling.
* Enhanced documentation with examples for new cache directory functions.

* [Refactor] Update cache imports in tilelang.__init__.py

* Added `set_cache_dir` and `get_cache_dir` functions to the import statement for improved cache directory management.
* This change enhances the accessibility of cache directory management functions within the module.
parent b1ba0cc8
......@@ -80,7 +80,7 @@ if SKIP_LOADING_TILELANG_SO == "0":
from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401
from .cache import cached # noqa: F401
from .cache import cached, set_cache_dir, get_cache_dir # noqa: F401
from .utils import (
TensorSupplyType, # noqa: F401
......
"""The cache utils with class and database persistence - Init file"""
from typing import List, Union, Literal, Optional
from pathlib import Path
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
......@@ -36,6 +37,25 @@ def cached(
)
def get_cache_dir() -> Path:
"""
Gets the cache directory for the kernel cache.
Example:
>>> tilelang.cache.get_cache_dir()
PosixPath('/Users/username/.tilelang/cache')
"""
return _kernel_cache_instance.get_cache_dir()
def set_cache_dir(cache_dir: str):
"""
Sets the cache directory for the kernel cache.
Example:
>>> tilelang.cache.set_cache_dir("/path/to/cache")
"""
_kernel_cache_instance.set_cache_dir(cache_dir)
def clear_cache():
"""
Clears the entire kernel cache (using KernelCache class).
......
......@@ -3,6 +3,7 @@
import os
import json
import shutil
from pathlib import Path
from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional
from tvm.target import Target
......@@ -35,6 +36,8 @@ class KernelCache:
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR)
def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""
Implements singleton pattern for KernelCache class.
......@@ -49,7 +52,7 @@ class KernelCache:
with cls._lock:
if cls._instance is None: # Double-checked locking
instance = super().__new__(cls)
instance.cache_dir = cache_dir
instance.cache_dir = Path(cache_dir)
os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__)
......@@ -184,6 +187,18 @@ class KernelCache:
self._memory_cache[key] = kernel
return kernel
def set_cache_dir(self, cache_dir: str):
"""
Sets the cache directory for the kernel cache.
"""
self.cache_dir = Path(cache_dir)
def get_cache_dir(self) -> Path:
"""
Gets the cache directory for the kernel cache.
"""
return self.cache_dir
def clear_cache(self):
"""
Clears the entire kernel cache, including both in-memory and disk cache.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment