Unverified Commit 653b873b authored by kk's avatar kk Committed by GitHub
Browse files

Fix cache modules of triton import error (#7832)

parent d379bda4
...@@ -83,12 +83,7 @@ from torch.func import functional_call ...@@ -83,12 +83,7 @@ from torch.func import functional_call
from torch.library import Library from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager from torch.utils._contextlib import _DecoratorContextManager
from triton.runtime.cache import ( from triton.runtime.cache import FileCacheManager
FileCacheManager,
default_cache_dir,
default_dump_dir,
default_override_dir,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -923,18 +918,41 @@ class CustomCacheManager(FileCacheManager): ...@@ -923,18 +918,41 @@ class CustomCacheManager(FileCacheManager):
self.key = key self.key = key
self.lock_path = None self.lock_path = None
try:
module_path = "triton.runtime.cache"
cache_module = importlib.import_module(module_path)
default_cache_dir = getattr(cache_module, "default_cache_dir", None)
default_dump_dir = getattr(cache_module, "default_dump_dir", None)
default_override_dir = getattr(cache_module, "default_override_dir", None)
except (ModuleNotFoundError, AttributeError) as e:
default_cache_dir = None
default_dump_dir = None
default_override_dir = None
if dump: if dump:
self.cache_dir = default_dump_dir() self.cache_dir = (
default_dump_dir()
if default_dump_dir is not None
else os.path.join(Path.home(), ".triton", "dump")
)
self.cache_dir = os.path.join(self.cache_dir, self.key) self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock") self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True) os.makedirs(self.cache_dir, exist_ok=True)
elif override: elif override:
self.cache_dir = default_override_dir() self.cache_dir = (
default_override_dir()
if default_override_dir is not None
else os.path.join(Path.home(), ".triton", "override")
)
self.cache_dir = os.path.join(self.cache_dir, self.key) self.cache_dir = os.path.join(self.cache_dir, self.key)
else: else:
# create cache directory if it doesn't exist # create cache directory if it doesn't exist
self.cache_dir = ( self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() default_cache_dir()
if default_cache_dir is not None
else os.path.join(Path.home(), ".triton", "cache")
) )
if self.cache_dir: if self.cache_dir:
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment