"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c015421cae0cd96974506262c3880f64950df7b1"
Commit 0aaef97d authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Add cache directory management functions in tilelang.cache (#453)

* [Feature] Add cache directory management functions in tilelang.cache

* Introduced `get_cache_dir` and `set_cache_dir` functions to manage the kernel cache directory.
* Updated `KernelCache` class to store cache directory as a `Path` object for improved path handling.
* Enhanced documentation with examples for new cache directory functions.

* [Refactor] Update cache imports in tilelang.__init__.py

* Added `set_cache_dir` and `get_cache_dir` functions to the import statement for improved cache directory management.
* This change enhances the accessibility of cache directory management functions within the module.
parent b1ba0cc8
...@@ -80,7 +80,7 @@ if SKIP_LOADING_TILELANG_SO == "0": ...@@ -80,7 +80,7 @@ if SKIP_LOADING_TILELANG_SO == "0":
from .jit import jit, JITKernel, compile # noqa: F401 from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401 from .profiler import Profiler # noqa: F401
from .cache import cached # noqa: F401 from .cache import cached, set_cache_dir, get_cache_dir # noqa: F401
from .utils import ( from .utils import (
TensorSupplyType, # noqa: F401 TensorSupplyType, # noqa: F401
......
"""The cache utils with class and database persistence - Init file""" """The cache utils with class and database persistence - Init file"""
from typing import List, Union, Literal, Optional from typing import List, Union, Literal, Optional
from pathlib import Path
from tvm.target import Target from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
...@@ -36,6 +37,25 @@ def cached( ...@@ -36,6 +37,25 @@ def cached(
) )
def get_cache_dir() -> Path:
"""
Gets the cache directory for the kernel cache.
Example:
>>> tilelang.cache.get_cache_dir()
PosixPath('/Users/username/.tilelang/cache')
"""
return _kernel_cache_instance.get_cache_dir()
def set_cache_dir(cache_dir: str):
"""
Sets the cache directory for the kernel cache.
Example:
>>> tilelang.cache.set_cache_dir("/path/to/cache")
"""
_kernel_cache_instance.set_cache_dir(cache_dir)
def clear_cache(): def clear_cache():
""" """
Clears the entire kernel cache (using KernelCache class). Clears the entire kernel cache (using KernelCache class).
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
import shutil import shutil
from pathlib import Path
from hashlib import sha256 from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional from typing import Callable, List, Literal, Union, Optional
from tvm.target import Target from tvm.target import Target
...@@ -35,6 +36,8 @@ class KernelCache: ...@@ -35,6 +36,8 @@ class KernelCache:
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary _memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR)
def __new__(cls, cache_dir=TILELANG_CACHE_DIR): def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
""" """
Implements singleton pattern for KernelCache class. Implements singleton pattern for KernelCache class.
...@@ -49,7 +52,7 @@ class KernelCache: ...@@ -49,7 +52,7 @@ class KernelCache:
with cls._lock: with cls._lock:
if cls._instance is None: # Double-checked locking if cls._instance is None: # Double-checked locking
instance = super().__new__(cls) instance = super().__new__(cls)
instance.cache_dir = cache_dir instance.cache_dir = Path(cache_dir)
os.makedirs(instance.cache_dir, exist_ok=True) os.makedirs(instance.cache_dir, exist_ok=True)
instance.logger = logging.getLogger(__name__) instance.logger = logging.getLogger(__name__)
...@@ -184,6 +187,18 @@ class KernelCache: ...@@ -184,6 +187,18 @@ class KernelCache:
self._memory_cache[key] = kernel self._memory_cache[key] = kernel
return kernel return kernel
def set_cache_dir(self, cache_dir: str):
"""
Sets the cache directory for the kernel cache.
"""
self.cache_dir = Path(cache_dir)
def get_cache_dir(self) -> Path:
"""
Gets the cache directory for the kernel cache.
"""
return self.cache_dir
def clear_cache(self): def clear_cache(self):
""" """
Clears the entire kernel cache, including both in-memory and disk cache. Clears the entire kernel cache, including both in-memory and disk cache.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment