"testing/python/language/test_tilelang_language_all_of.py" did not exist on "9a7a569dda862389c896ab8d7fb3aa9655219030"
Unverified Commit 72be4909 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Refactor env into a more flexible version (#740)

* Fix environment variable name for compilation print setting in `env.py`

* Remove deprecated test file for warp specialized pass configuration and refactor environment variable access in `env.py` to utilize a centralized `EnvVar` class for better management and clarity.

* lint fix

* Refactor cache check to use `env.is_cache_enabled()` for consistency in `tuner.py`
parent e3a80b70
import tilelang
import os
def test_env_var():
    """Exercise EnvVar resolution order: forced override > live os.environ > default.

    The original version mutated ``os.environ`` without cleanup and its first
    assertion failed whenever TILELANG_PRINT_ON_COMPILATION was already set in
    the calling shell; this version isolates and restores the process env.
    """
    var = "TILELANG_PRINT_ON_COMPILATION"
    saved = os.environ.pop(var, None)  # ensure the default-value branch is actually hit
    try:
        # test default value (variable absent from the environment)
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
        # test live read: a change to os.environ is visible immediately
        os.environ[var] = "0"
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0"
        # test forced value via attribute assignment (wins over os.environ)
        tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1"
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
        # NOTE(review): the forced override set above persists on the EnvVar
        # descriptor after this test — confirm whether a reset API is needed.
    finally:
        if saved is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = saved


if __name__ == "__main__":
    test_env_var()
......@@ -53,8 +53,8 @@ _init_logger()
logger = logging.getLogger(__name__)
from .env import SKIP_LOADING_TILELANG_SO
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
from .env import env as env # noqa: F401
import tvm
import tvm.base
......@@ -76,12 +76,12 @@ def _load_tile_lang_lib():
# only load once here
if SKIP_LOADING_TILELANG_SO == "0":
if env.SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()
from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401
from .cache import cached # noqa: F401
from .cache import clear_cache # noqa: F401
from .utils import (
TensorSupplyType, # noqa: F401
......
......@@ -25,13 +25,7 @@ import threading
import traceback
from pathlib import Path
from tilelang.env import (
TILELANG_CACHE_DIR,
TILELANG_AUTO_TUNING_CPU_UTILITIES,
TILELANG_AUTO_TUNING_CPU_COUNTS,
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
is_cache_enabled,
)
from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.jit.param import _P, _RProg
......@@ -111,7 +105,7 @@ class AutoTuner:
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
def __init__(self, fn: Callable, configs):
self.fn = fn
......@@ -285,7 +279,7 @@ class AutoTuner:
key = self.generate_cache_key(parameters)
with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
# First check in-memory cache
if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \
......@@ -437,9 +431,9 @@ class AutoTuner:
return autotuner_result
# get the cpu count
available_cpu_count = get_available_cpu_count()
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count)
logger.info(
......@@ -543,7 +537,7 @@ class AutoTuner:
logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result
......
......@@ -4,8 +4,8 @@ from typing import List, Union, Literal, Optional
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang import env
from .kernel_cache import KernelCache
from tilelang.env import TILELANG_CLEAR_CACHE
# Create singleton instance of KernelCache
_kernel_cache_instance = KernelCache()
......@@ -44,5 +44,5 @@ def clear_cache():
_kernel_cache_instance.clear_cache()
if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
clear_cache()
......@@ -14,7 +14,7 @@ from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.engine.param import KernelParam
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled
from tilelang import env
from tilelang.jit import JITKernel
from tilelang.version import __version__
......@@ -61,8 +61,8 @@ class KernelCache:
@staticmethod
def _create_dirs():
os.makedirs(TILELANG_CACHE_DIR, exist_ok=True)
os.makedirs(TILELANG_TMP_DIR, exist_ok=True)
os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True)
os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True)
def _generate_key(
self,
......@@ -132,7 +132,7 @@ class KernelCache:
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
if not is_cache_enabled():
if not env.is_cache_enabled():
return JITKernel(
func,
out_idx=out_idx,
......@@ -190,7 +190,7 @@ class KernelCache:
self.logger.warning("DLPack backend does not support cache saving to disk.")
else:
with self._lock:
if is_cache_enabled():
if env.is_cache_enabled():
self._save_kernel_to_disk(key, kernel, func, verbose)
# Store in memory cache after compilation
......@@ -215,7 +215,7 @@ class KernelCache:
Returns:
str: Absolute path to the cache directory for this kernel.
"""
return os.path.join(TILELANG_CACHE_DIR, key)
return os.path.join(env.TILELANG_CACHE_DIR, key)
@staticmethod
def _load_binary(path: str):
......@@ -226,7 +226,7 @@ class KernelCache:
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable):
# Random a temporary file within the same FS as the cache directory
temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
with open(temp_path, mode) as temp_file:
operation(temp_file)
......@@ -396,7 +396,7 @@ class KernelCache:
"""
try:
# Delete the entire cache directory
shutil.rmtree(TILELANG_CACHE_DIR)
shutil.rmtree(env.TILELANG_CACHE_DIR)
# Re-create the cache directory
KernelCache._create_dirs()
......
......@@ -6,7 +6,7 @@ from __future__ import absolute_import as _abs
import os
import subprocess
import warnings
from ..env import CUDA_HOME
from tilelang.env import CUDA_HOME
import tvm.ffi
from tvm.target import Target
......
......@@ -4,9 +4,21 @@ import pathlib
import logging
import shutil
import glob
from dataclasses import dataclass
from typing import Optional
logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES
#
# Bug fix: in the original, each parenthesis closed before the second string
# literal, so the ", which may lead ..." tail was a discarded no-op expression
# statement and never became part of the message. Keeping both literals inside
# the parentheses makes implicit string concatenation part of the assignment.
CUTLASS_NOT_FOUND_MESSAGE = (
    "CUTLASS is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
    "Composable Kernel is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

TL_TEMPLATE_NOT_FOUND_MESSAGE = (
    "TileLang is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"
def _find_cuda_home() -> str:
"""Find the CUDA install path.
......@@ -46,76 +58,200 @@ def _find_rocm_home() -> str:
return rocm_home if rocm_home is not None else ""
def _initialize_torch_cuda_arch_flags():
import os
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
# Cache control
class CacheState:
"""Class to manage global kernel caching state."""
_enabled = True
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = nvcc.get_target_compute_version(target)
major, minor = nvcc.parse_compute_version(compute_version)
@classmethod
def enable(cls):
"""Enable kernel caching globally."""
cls._enabled = True
# set TORCH_CUDA_ARCH_LIST
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
@classmethod
def disable(cls):
"""Disable kernel caching globally."""
cls._enabled = False
@classmethod
def is_enabled(cls) -> bool:
"""Return current cache state."""
return cls._enabled
CUDA_HOME = _find_cuda_home()
ROCM_HOME = _find_rocm_home()
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None)
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
@dataclass
class EnvVar:
    """
    Descriptor for managing access to a single environment variable.

    Purpose
    -------
    Centralizes the definition of the variable's **key** and **default value**,
    reads *live* from ``os.environ`` so changes take effect immediately, and
    supports **forced overrides** at runtime (useful for unit tests/debugging).

    How it works
    ------------
    - Implements the descriptor protocol (``__get__``, ``__set__``).
    - When used as a class attribute, ``instance.attr`` triggers ``__get__()``
      and returns either the forced override or the live value from ``os.environ``.
    - Assigning ``instance.attr = value`` triggers ``__set__()`` which stores
      ``_forced_value`` for future reads.
    - Uncomment the ``os.environ[...] = value`` line in ``__set__`` if the
      override should also persist globally in the process.

    Example
    -------
    ```python
    class Environment:
        TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "0")

    env = Environment()
    print(env.TILELANG_PRINT_ON_COMPILATION)  # reads os.environ, falls back to default
    env.TILELANG_PRINT_ON_COMPILATION = "1"   # forces the value until changed/reset
    ```
    """

    # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
    key: str
    # Default value if the environment variable is not set
    default: str
    # Temporary runtime override (mainly for tests/debugging)
    _forced_value: Optional[str] = None

    def get(self) -> Optional[str]:
        """Return the forced override if set, else the live env value or default."""
        if self._forced_value is not None:
            return self._forced_value
        return os.environ.get(self.key, self.default)

    def __get__(self, instance, owner):
        """
        Called when the attribute is accessed.

        1. If a forced value is set, return it.
        2. Otherwise look up the value in os.environ; return the default if missing.
        """
        return self.get()

    def __set__(self, instance, value):
        """
        Called when the attribute is assigned to.

        Stores the value as a runtime override (forced value). Note the override
        lives on the descriptor, which is shared by every instance of the owner
        class — fine for a singleton ``env`` object.
        """
        self._forced_value = value
        # Uncomment the following line if you want the override to persist globally:
        # os.environ[self.key] = value
# Cache control API (wrap CacheState)
enable_cache = CacheState.enable
disable_cache = CacheState.disable
is_cache_enabled = CacheState.is_enabled
# SETUP ENVIRONMENT VARIABLES
#
# Bug fix: the parenthesis previously closed before the second string literal,
# leaving the ", which may lead ..." tail as a discarded bare expression.
# Implicit concatenation inside the parentheses restores the full message.
CUTLASS_NOT_FOUND_MESSAGE = (
    "CUTLASS is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
    "Composable Kernel is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

TL_TEMPLATE_NOT_FOUND_MESSAGE = (
    "TileLang is not installed or found in the expected path"
    ", which may lead to compilation bugs when utilize tilelang backend.")

TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path"
SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
# Utility function for environment variables with defaults
# Assuming EnvVar and CacheState are defined elsewhere
class Environment:
    """
    Environment configuration for TileLang.

    Handles CUDA/ROCm detection, integration paths, template/cache locations,
    auto-tuning configs, and build options. String-valued settings are EnvVar
    descriptors: reads are live from ``os.environ`` with a default fallback,
    and assignment forces a runtime override.
    """

    # CUDA/ROCm home directories (resolved once at class-definition time)
    CUDA_HOME = _find_cuda_home()
    ROCM_HOME = _find_rocm_home()

    # Path to the TileLang package root
    TILELANG_PACKAGE_PATH = pathlib.Path(__file__).resolve().parent

    # External library include paths
    CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None)
    COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None)

    # TVM integration
    TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
    TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None)

    # TileLang resources
    TILELANG_TEMPLATE_PATH = EnvVar("TL_TEMPLATE_PATH", None)
    TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", os.path.expanduser("~/.tilelang/cache"))
    # NOTE(review): the tmp-dir default snapshots TILELANG_CACHE_DIR at
    # class-definition time; later changes to TILELANG_CACHE_DIR do not move
    # the default tmp dir — confirm this is intended.
    TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))

    # Kernel build options
    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
                                           "1")  # print kernel name on compile
    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # clear cache automatically if set

    # Auto-tuning settings
    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
                                                "0.9")  # percent of CPUs used
    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
                                             "-1")  # -1 means auto
    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
                                                "-1")  # -1 means no limit

    # Library loading / TVM import control
    SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
    TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)

    def _initialize_torch_cuda_arch_flags(self) -> None:
        """
        Detect the target CUDA architecture and set TORCH_CUDA_ARCH_LIST
        so PyTorch extensions are built for the proper GPU arch.
        """
        # Local imports keep module import light and avoid cycles.
        from tilelang.contrib import nvcc
        from tilelang.utils.target import determine_target

        target = determine_target(return_object=True)  # get target GPU
        compute_version = nvcc.get_target_compute_version(target)  # e.g. "8.6"
        major, minor = nvcc.parse_compute_version(compute_version)  # split to (8, 6)
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"  # consumed by torch cpp_extension

    # Cache control API (thin wrappers around the global CacheState)
    def is_cache_enabled(self) -> bool:
        return CacheState.is_enabled()

    def enable_cache(self) -> None:
        CacheState.enable()

    def disable_cache(self) -> None:
        CacheState.disable()
# Instantiate as a global configuration object
env = Environment()
# Export CUDA_HOME and ROCM_HOME, both are static variables
# after initialization.
CUDA_HOME = env.CUDA_HOME
ROCM_HOME = env.ROCM_HOME
# Initialize TVM paths
if env.TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = env.TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, env.TVM_IMPORT_PYTHON_PATH)
else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, install_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python"
env.TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python"
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
......@@ -123,7 +259,7 @@ else:
os.environ["PYTHONPATH"] = (
develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"
env.TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
......@@ -136,14 +272,15 @@ else:
else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
# pip install build library path
lib_path = os.path.join(TILELANG_PACKAGE_PATH, "lib")
lib_path = os.path.join(env.TILELANG_PACKAGE_PATH, "lib")
existing_path = os.environ.get("TVM_LIBRARY_PATH")
if existing_path:
os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}"
else:
os.environ["TVM_LIBRARY_PATH"] = lib_path
TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
env.TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
# Initialize CUTLASS paths
if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
......@@ -151,13 +288,14 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include"
env.CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include"
elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include"
env.CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include"
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
# Initialize COMPOSABLE_KERNEL paths
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
install_ck_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel")
......@@ -165,63 +303,27 @@ if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel")
if os.path.exists(install_ck_path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = install_ck_path + "/include"
COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include"
env.COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include"
elif (os.path.exists(develop_ck_path) and develop_ck_path not in sys.path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = develop_ck_path + "/include"
COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include"
env.COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include"
else:
logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE)
# Initialize TL_TEMPLATE_PATH
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
if os.path.exists(install_tl_template_path):
os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
TILELANG_TEMPLATE_PATH = install_tl_template_path
env.TILELANG_TEMPLATE_PATH = install_tl_template_path
elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
TILELANG_TEMPLATE_PATH = develop_tl_template_path
env.TILELANG_TEMPLATE_PATH = develop_tl_template_path
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
# Cache control
class CacheState:
    """Process-wide toggle for kernel caching.

    The flag is a single class attribute, and all mutators are classmethods,
    so the state is shared by every caller in the process.
    """

    _enabled: bool = True

    @classmethod
    def enable(cls) -> None:
        """Turn kernel caching on globally."""
        cls._enabled = True

    @classmethod
    def disable(cls) -> None:
        """Turn kernel caching off globally."""
        cls._enabled = False

    @classmethod
    def is_enabled(cls) -> bool:
        """Report whether kernel caching is currently active."""
        return cls._enabled
# Replace the old functions with class methods
enable_cache = CacheState.enable
disable_cache = CacheState.disable
is_cache_enabled = CacheState.is_enabled
# Public API surface of the module: the names re-exported via
# `from tilelang.env import ...` elsewhere in the codebase.
__all__ = [
"CUTLASS_INCLUDE_DIR",
"COMPOSABLE_KERNEL_INCLUDE_DIR",
"TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH",
"CUDA_HOME",
"ROCM_HOME",
"TILELANG_CACHE_DIR",
"enable_cache",
"disable_cache",
"is_cache_enabled",
"_initialize_torch_cuda_arch_flags",
]
# Export static variables after initialization.
# Reading `env.<attr>` triggers EnvVar.__get__, so these module-level names
# are one-time snapshots taken at import; later environment changes or forced
# overrides on `env` are NOT reflected in these copies.
CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR
COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR
TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH
......@@ -4,9 +4,9 @@ from tvm.target import Target
from tvm.tir import PrimFunc
import tilelang
from tilelang import tvm as tvm
from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.env import TILELANG_PRINT_ON_COMPILATION
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
NVRTCKernelAdapter, TorchDLPackKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType
......@@ -114,7 +114,7 @@ class JITKernel(object):
# Print log on compilation starts
# NOTE(Chenggang): printing could let the training/inference framework easier to know
# whether the communication timeout is from compilation
if TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
print(f"TileLang begins to compile kernel `{func.__name__}` with `{out_idx=}`")
# Compile the TileLang function and create a kernel adapter for execution.
......
......@@ -2,12 +2,12 @@ import os
import torch
import warnings
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
from tilelang import env
# Define paths
compress_util = os.path.join(TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
compress_util = os.path.join(env.TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
# Cache directory for compiled extensions
_CACHE_DIR = os.path.join(TILELANG_CACHE_DIR, "sparse_compressor")
_CACHE_DIR = os.path.join(env.TILELANG_CACHE_DIR, "sparse_compressor")
os.makedirs(_CACHE_DIR, exist_ok=True)
......@@ -22,9 +22,8 @@ def _get_cached_lib():
# If loading fails, recompile
pass
from tilelang.env import _initialize_torch_cuda_arch_flags
# Set TORCH_CUDA_ARCH_LIST
_initialize_torch_cuda_arch_flags()
env._initialize_torch_cuda_arch_flags()
# Compile if not cached or loading failed
return load(
......@@ -34,8 +33,8 @@ def _get_cached_lib():
'-O2',
'-std=c++17',
'-lineinfo',
f'-I{CUTLASS_INCLUDE_DIR}',
f'-I{CUTLASS_INCLUDE_DIR}/../tools/util/include',
f'-I{env.CUTLASS_INCLUDE_DIR}',
f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
'-arch=sm_90',
],
build_directory=_CACHE_DIR,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment