Unverified Commit 72be4909 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Refactor env into a more flexible version (#740)

* Fix environment variable name for compilation print setting in `env.py`

* Remove deprecated test file for warp specialized pass configuration and refactor environment variable access in `env.py` to utilize a centralized `EnvVar` class for better management and clarity.

* lint fix

* Refactor cache check to use `env.is_cache_enabled()` for consistency in `tuner.py`
parent e3a80b70
import tilelang
import os
def test_env_var():
    """Verify resolution order of ``tilelang.env.TILELANG_PRINT_ON_COMPILATION``.

    Exercises the three lookup behaviors of the ``EnvVar``-backed attribute:
    1. the built-in default ("1") when the process env var is unset,
    2. a live read from ``os.environ`` (changes take effect immediately),
    3. a runtime override via attribute assignment, which takes precedence.

    Assumes the env var is not already set when the test starts — the first
    assertion checks the default value.  # TODO confirm in CI environment
    """
    # Snapshot the process-level variable so this test does not leak
    # environment state into other tests running in the same process.
    saved = os.environ.get("TILELANG_PRINT_ON_COMPILATION")
    try:
        # 1. Default value: nothing set in os.environ, no forced override.
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
        # 2. Live read: EnvVar.__get__ consults os.environ on every access.
        os.environ["TILELANG_PRINT_ON_COMPILATION"] = "0"
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0"
        # 3. Forced override: assignment goes through EnvVar.__set__ and
        # wins over the os.environ value set above.
        tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1"
        assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
        # NOTE(review): the forced value set above persists on the global
        # `env` object after this test; it happens to equal the default, but
        # a public "reset override" API would make this cleaner — confirm
        # whether EnvVar exposes one.
    finally:
        # Restore the original environment variable (or remove it if it
        # was absent) so later tests see the pre-test environment.
        if saved is None:
            os.environ.pop("TILELANG_PRINT_ON_COMPILATION", None)
        else:
            os.environ["TILELANG_PRINT_ON_COMPILATION"] = saved


if __name__ == "__main__":
    test_env_var()
...@@ -53,8 +53,8 @@ _init_logger() ...@@ -53,8 +53,8 @@ _init_logger()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .env import SKIP_LOADING_TILELANG_SO
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401 from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
from .env import env as env # noqa: F401
import tvm import tvm
import tvm.base import tvm.base
...@@ -76,12 +76,12 @@ def _load_tile_lang_lib(): ...@@ -76,12 +76,12 @@ def _load_tile_lang_lib():
# only load once here # only load once here
if SKIP_LOADING_TILELANG_SO == "0": if env.SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib() _LIB, _LIB_PATH = _load_tile_lang_lib()
from .jit import jit, JITKernel, compile # noqa: F401 from .jit import jit, JITKernel, compile # noqa: F401
from .profiler import Profiler # noqa: F401 from .profiler import Profiler # noqa: F401
from .cache import cached # noqa: F401 from .cache import clear_cache # noqa: F401
from .utils import ( from .utils import (
TensorSupplyType, # noqa: F401 TensorSupplyType, # noqa: F401
......
...@@ -25,13 +25,7 @@ import threading ...@@ -25,13 +25,7 @@ import threading
import traceback import traceback
from pathlib import Path from pathlib import Path
from tilelang.env import ( from tilelang import env
TILELANG_CACHE_DIR,
TILELANG_AUTO_TUNING_CPU_UTILITIES,
TILELANG_AUTO_TUNING_CPU_COUNTS,
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
is_cache_enabled,
)
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.autotuner.capture import get_autotune_inputs from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.jit.param import _P, _RProg from tilelang.jit.param import _P, _RProg
...@@ -111,7 +105,7 @@ class AutoTuner: ...@@ -111,7 +105,7 @@ class AutoTuner:
_kernel_parameters: Optional[Tuple[str, ...]] = None _kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary _memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner" cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
def __init__(self, fn: Callable, configs): def __init__(self, fn: Callable, configs):
self.fn = fn self.fn = fn
...@@ -285,7 +279,7 @@ class AutoTuner: ...@@ -285,7 +279,7 @@ class AutoTuner:
key = self.generate_cache_key(parameters) key = self.generate_cache_key(parameters)
with self._lock: with self._lock:
if is_cache_enabled(): if env.is_cache_enabled():
# First check in-memory cache # First check in-memory cache
if key in self._memory_cache: if key in self._memory_cache:
logger.warning("Found kernel in memory cache. For better performance," \ logger.warning("Found kernel in memory cache. For better performance," \
...@@ -437,9 +431,9 @@ class AutoTuner: ...@@ -437,9 +431,9 @@ class AutoTuner:
return autotuner_result return autotuner_result
# get the cpu count # get the cpu count
available_cpu_count = get_available_cpu_count() available_cpu_count = get_available_cpu_count()
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES) cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES)
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS) cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS)
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT) max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
if cpu_counts > 0: if cpu_counts > 0:
num_workers = min(cpu_counts, available_cpu_count) num_workers = min(cpu_counts, available_cpu_count)
logger.info( logger.info(
...@@ -543,7 +537,7 @@ class AutoTuner: ...@@ -543,7 +537,7 @@ class AutoTuner:
logger.warning("DLPack backend does not support cache saving to disk.") logger.warning("DLPack backend does not support cache saving to disk.")
else: else:
with self._lock: with self._lock:
if is_cache_enabled(): if env.is_cache_enabled():
self._save_result_to_disk(key, autotuner_result) self._save_result_to_disk(key, autotuner_result)
self._memory_cache[key] = autotuner_result self._memory_cache[key] = autotuner_result
......
...@@ -4,8 +4,8 @@ from typing import List, Union, Literal, Optional ...@@ -4,8 +4,8 @@ from typing import List, Union, Literal, Optional
from tvm.target import Target from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
from tilelang import env
from .kernel_cache import KernelCache from .kernel_cache import KernelCache
from tilelang.env import TILELANG_CLEAR_CACHE
# Create singleton instance of KernelCache # Create singleton instance of KernelCache
_kernel_cache_instance = KernelCache() _kernel_cache_instance = KernelCache()
...@@ -44,5 +44,5 @@ def clear_cache(): ...@@ -44,5 +44,5 @@ def clear_cache():
_kernel_cache_instance.clear_cache() _kernel_cache_instance.clear_cache()
if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"): if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
clear_cache() clear_cache()
...@@ -14,7 +14,7 @@ from tvm.target import Target ...@@ -14,7 +14,7 @@ from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled from tilelang import env
from tilelang.jit import JITKernel from tilelang.jit import JITKernel
from tilelang.version import __version__ from tilelang.version import __version__
...@@ -61,8 +61,8 @@ class KernelCache: ...@@ -61,8 +61,8 @@ class KernelCache:
@staticmethod @staticmethod
def _create_dirs(): def _create_dirs():
os.makedirs(TILELANG_CACHE_DIR, exist_ok=True) os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True)
os.makedirs(TILELANG_TMP_DIR, exist_ok=True) os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True)
def _generate_key( def _generate_key(
self, self,
...@@ -132,7 +132,7 @@ class KernelCache: ...@@ -132,7 +132,7 @@ class KernelCache:
Returns: Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache JITKernel: The compiled kernel, either freshly compiled or from cache
""" """
if not is_cache_enabled(): if not env.is_cache_enabled():
return JITKernel( return JITKernel(
func, func,
out_idx=out_idx, out_idx=out_idx,
...@@ -190,7 +190,7 @@ class KernelCache: ...@@ -190,7 +190,7 @@ class KernelCache:
self.logger.warning("DLPack backend does not support cache saving to disk.") self.logger.warning("DLPack backend does not support cache saving to disk.")
else: else:
with self._lock: with self._lock:
if is_cache_enabled(): if env.is_cache_enabled():
self._save_kernel_to_disk(key, kernel, func, verbose) self._save_kernel_to_disk(key, kernel, func, verbose)
# Store in memory cache after compilation # Store in memory cache after compilation
...@@ -215,7 +215,7 @@ class KernelCache: ...@@ -215,7 +215,7 @@ class KernelCache:
Returns: Returns:
str: Absolute path to the cache directory for this kernel. str: Absolute path to the cache directory for this kernel.
""" """
return os.path.join(TILELANG_CACHE_DIR, key) return os.path.join(env.TILELANG_CACHE_DIR, key)
@staticmethod @staticmethod
def _load_binary(path: str): def _load_binary(path: str):
...@@ -226,7 +226,7 @@ class KernelCache: ...@@ -226,7 +226,7 @@ class KernelCache:
@staticmethod @staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable): def _safe_write_file(path: str, mode: str, operation: Callable):
# Random a temporary file within the same FS as the cache directory # Random a temporary file within the same FS as the cache directory
temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}") temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
with open(temp_path, mode) as temp_file: with open(temp_path, mode) as temp_file:
operation(temp_file) operation(temp_file)
...@@ -396,7 +396,7 @@ class KernelCache: ...@@ -396,7 +396,7 @@ class KernelCache:
""" """
try: try:
# Delete the entire cache directory # Delete the entire cache directory
shutil.rmtree(TILELANG_CACHE_DIR) shutil.rmtree(env.TILELANG_CACHE_DIR)
# Re-create the cache directory # Re-create the cache directory
KernelCache._create_dirs() KernelCache._create_dirs()
......
...@@ -6,7 +6,7 @@ from __future__ import absolute_import as _abs ...@@ -6,7 +6,7 @@ from __future__ import absolute_import as _abs
import os import os
import subprocess import subprocess
import warnings import warnings
from ..env import CUDA_HOME from tilelang.env import CUDA_HOME
import tvm.ffi import tvm.ffi
from tvm.target import Target from tvm.target import Target
......
...@@ -4,9 +4,21 @@ import pathlib ...@@ -4,9 +4,21 @@ import pathlib
import logging import logging
import shutil import shutil
import glob import glob
from dataclasses import dataclass
from typing import Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
"Composable Kernel is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
def _find_cuda_home() -> str: def _find_cuda_home() -> str:
"""Find the CUDA install path. """Find the CUDA install path.
...@@ -46,76 +58,200 @@ def _find_rocm_home() -> str: ...@@ -46,76 +58,200 @@ def _find_rocm_home() -> str:
return rocm_home if rocm_home is not None else "" return rocm_home if rocm_home is not None else ""
def _initialize_torch_cuda_arch_flags(): # Cache control
import os class CacheState:
from tilelang.contrib import nvcc """Class to manage global kernel caching state."""
from tilelang.utils.target import determine_target _enabled = True
target = determine_target(return_object=True) @classmethod
# create tmp source file for torch cpp extension def enable(cls):
compute_version = nvcc.get_target_compute_version(target) """Enable kernel caching globally."""
major, minor = nvcc.parse_compute_version(compute_version) cls._enabled = True
# set TORCH_CUDA_ARCH_LIST @classmethod
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" def disable(cls):
"""Disable kernel caching globally."""
cls._enabled = False
@classmethod
def is_enabled(cls) -> bool:
"""Return current cache state."""
return cls._enabled
CUDA_HOME = _find_cuda_home()
ROCM_HOME = _find_rocm_home()
CUTLASS_INCLUDE_DIR: str = os.environ.get("TL_CUTLASS_PATH", None) @dataclass
COMPOSABLE_KERNEL_INCLUDE_DIR: str = os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) class EnvVar:
TVM_PYTHON_PATH: str = os.environ.get("TVM_IMPORT_PYTHON_PATH", None) """
TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None) Descriptor for managing access to a single environment variable.
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0] Purpose
-------
In many projects, access to environment variables is scattered across the codebase:
* `os.environ.get(...)` calls are repeated everywhere
* Default values are hard-coded in multiple places
* Overriding env vars for tests/debugging is messy
* There's no central place to see all environment variables a package uses
This descriptor solves those issues by:
1. Centralizing the definition of the variable's **key** and **default value**
2. Allowing *dynamic* reads from `os.environ` so changes take effect immediately
3. Supporting **forced overrides** at runtime (for unit tests or debugging)
4. Logging a warning when a forced value is used (helps detect unexpected overrides)
5. Optionally syncing forced values back to `os.environ` if global consistency is desired
How it works
------------
- This is a `dataclass` implementing the descriptor protocol (`__get__`, `__set__`)
- When used as a class attribute, `instance.attr` triggers `__get__()`
→ returns either the forced override or the live value from `os.environ`
- Assigning to the attribute (`instance.attr = value`) triggers `__set__()`
→ stores `_forced_value` for future reads
- You may uncomment the `os.environ[...] = value` line in `__set__` if you want
the override to persist globally in the process
Example
-------
```python
class Environment:
TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "0")
env = Environment()
print(cfg.TILELANG_PRINT_ON_COMPILATION) # Reads from os.environ (with default fallback)
cfg.TILELANG_PRINT_ON_COMPILATION = "1" # Forces value to "1" until changed/reset
```
Benefits
--------
* Centralizes all env-var keys and defaults in one place
* Live, up-to-date reads (no stale values after `import`)
* Testing convenience (override without touching the real env)
* Improves IDE discoverability and type hints
* Avoids hardcoding `os.environ.get(...)` in multiple places
"""
TILELANG_CACHE_DIR: str = os.environ.get("TILELANG_CACHE_DIR", key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
os.path.expanduser("~/.tilelang/cache")) default: str # Default value if the environment variable is not set
TILELANG_TMP_DIR: str = os.path.join(TILELANG_CACHE_DIR, "tmp") _forced_value: Optional[str] = None # Temporary runtime override (mainly for tests/debugging)
# Print the kernel name on every compilation def get(self):
TILELANG_PRINT_ON_COMPILATION: str = os.environ.get("TILELANG_PRINT_COMPILATION", "0") if self._forced_value is not None:
return self._forced_value
return os.environ.get(self.key, self.default)
# Auto-clear cache if environment variable is set def __get__(self, instance, owner):
TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0") """
Called when the attribute is accessed.
1. If a forced value is set, return it and log a warning
2. Otherwise, look up the value in os.environ; return the default if missing
"""
return self.get()
# CPU Utilizations for Auto-Tuning, default is 0.9 def __set__(self, instance, value):
TILELANG_AUTO_TUNING_CPU_UTILITIES: str = os.environ.get("TILELANG_AUTO_TUNING_CPU_UTILITIES", """
"0.9") Called when the attribute is assigned to.
Stores the value as a runtime override (forced value).
Optionally, you can also sync this into os.environ for global effect.
"""
self._forced_value = value
# Uncomment the following line if you want the override to persist globally:
# os.environ[self.key] = value
# CPU COUNTS for Auto-Tuning, default is -1,
# which will use TILELANG_AUTO_TUNING_CPU_UTILITIES * get_available_cpu_count()
TILELANG_AUTO_TUNING_CPU_COUNTS: str = os.environ.get("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1")
# Max CPU Count for Auto-Tuning, default is 100 # Cache control API (wrap CacheState)
TILELANG_AUTO_TUNING_MAX_CPU_COUNT: str = os.environ.get("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") enable_cache = CacheState.enable
disable_cache = CacheState.disable
is_cache_enabled = CacheState.is_enabled
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
"Composable Kernel is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0") # Utility function for environment variables with defaults
# Assuming EnvVar and CacheState are defined elsewhere
class Environment:
"""
Environment configuration for TileLang.
Handles CUDA/ROCm detection, integration paths, template/cache locations,
auto-tuning configs, and build options.
"""
# CUDA/ROCm home directories
CUDA_HOME = _find_cuda_home()
ROCM_HOME = _find_rocm_home()
# Path to the TileLang package root
TILELANG_PACKAGE_PATH = pathlib.Path(__file__).resolve().parent
# External library include paths
CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None)
# TVM integration
TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None)
# TileLang resources
TILELANG_TEMPLATE_PATH = EnvVar("TL_TEMPLATE_PATH", None)
TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", os.path.expanduser("~/.tilelang/cache"))
TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))
# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path # Kernel Build options
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None) TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
"1") # print kernel name on compile
TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0") # clear cache automatically if set
if TVM_IMPORT_PYTHON_PATH is not None: # Auto-tuning settings
os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "") TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH) "0.9") # percent of CPUs used
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
"-1") # -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
"-1") # -1 means no limit
# TVM integration
SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
def _initialize_torch_cuda_arch_flags(self) -> None:
"""
Detect target CUDA architecture and set TORCH_CUDA_ARCH_LIST
to ensure PyTorch extensions are built for the proper GPU arch.
"""
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
target = determine_target(return_object=True) # get target GPU
compute_version = nvcc.get_target_compute_version(target) # e.g. "8.6"
major, minor = nvcc.parse_compute_version(compute_version) # split to (8, 6)
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" # set env var for PyTorch
# Cache control API (wrap CacheState)
def is_cache_enabled(self) -> bool:
return CacheState.is_enabled()
def enable_cache(self) -> None:
CacheState.enable()
def disable_cache(self) -> None:
CacheState.disable()
# Instantiate as a global configuration object
env = Environment()
# Export CUDA_HOME and ROCM_HOME, both are static variables
# after initialization.
CUDA_HOME = env.CUDA_HOME
ROCM_HOME = env.ROCM_HOME
# Initialize TVM paths
if env.TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = env.TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, env.TVM_IMPORT_PYTHON_PATH)
else: else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = ( os.environ["PYTHONPATH"] = (
install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, install_tvm_path + "/python") sys.path.insert(0, install_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python" env.TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python"
develop_tvm_path = os.path.join( develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
...@@ -123,7 +259,7 @@ else: ...@@ -123,7 +259,7 @@ else:
os.environ["PYTHONPATH"] = ( os.environ["PYTHONPATH"] = (
develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tvm_path + "/python") sys.path.insert(0, develop_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python" env.TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"
develop_tvm_library_path = os.path.join( develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm") os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
...@@ -136,14 +272,15 @@ else: ...@@ -136,14 +272,15 @@ else:
else: else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE) logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
# pip install build library path # pip install build library path
lib_path = os.path.join(TILELANG_PACKAGE_PATH, "lib") lib_path = os.path.join(env.TILELANG_PACKAGE_PATH, "lib")
existing_path = os.environ.get("TVM_LIBRARY_PATH") existing_path = os.environ.get("TVM_LIBRARY_PATH")
if existing_path: if existing_path:
os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}" os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}"
else: else:
os.environ["TVM_LIBRARY_PATH"] = lib_path os.environ["TVM_LIBRARY_PATH"] = lib_path
TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None) env.TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
# Initialize CUTLASS paths
if os.environ.get("TL_CUTLASS_PATH", None) is None: if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join( install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
...@@ -151,13 +288,14 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None: ...@@ -151,13 +288,14 @@ if os.environ.get("TL_CUTLASS_PATH", None) is None:
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path): if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include" env.CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include"
elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path): elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include" env.CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include"
else: else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE) logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
# Initialize COMPOSABLE_KERNEL paths
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
install_ck_path = os.path.join( install_ck_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel") os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel")
...@@ -165,63 +303,27 @@ if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None: ...@@ -165,63 +303,27 @@ if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel") os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel")
if os.path.exists(install_ck_path): if os.path.exists(install_ck_path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = install_ck_path + "/include" os.environ["TL_COMPOSABLE_KERNEL_PATH"] = install_ck_path + "/include"
COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include" env.COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include"
elif (os.path.exists(develop_ck_path) and develop_ck_path not in sys.path): elif (os.path.exists(develop_ck_path) and develop_ck_path not in sys.path):
os.environ["TL_COMPOSABLE_KERNEL_PATH"] = develop_ck_path + "/include" os.environ["TL_COMPOSABLE_KERNEL_PATH"] = develop_ck_path + "/include"
COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include" env.COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include"
else: else:
logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE) logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE)
# Initialize TL_TEMPLATE_PATH
if os.environ.get("TL_TEMPLATE_PATH", None) is None: if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src") install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src") develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
if os.path.exists(install_tl_template_path): if os.path.exists(install_tl_template_path):
os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
TILELANG_TEMPLATE_PATH = install_tl_template_path env.TILELANG_TEMPLATE_PATH = install_tl_template_path
elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path): elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
TILELANG_TEMPLATE_PATH = develop_tl_template_path env.TILELANG_TEMPLATE_PATH = develop_tl_template_path
else: else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
# Export static variables after initialization.
# Cache control CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR
class CacheState: COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR
"""Class to manage global kernel caching state.""" TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH
_enabled = True
@classmethod
def enable(cls):
"""Enable kernel caching globally."""
cls._enabled = True
@classmethod
def disable(cls):
"""Disable kernel caching globally."""
cls._enabled = False
@classmethod
def is_enabled(cls) -> bool:
"""Return current cache state."""
return cls._enabled
# Replace the old functions with class methods
enable_cache = CacheState.enable
disable_cache = CacheState.disable
is_cache_enabled = CacheState.is_enabled
__all__ = [
"CUTLASS_INCLUDE_DIR",
"COMPOSABLE_KERNEL_INCLUDE_DIR",
"TVM_PYTHON_PATH",
"TVM_LIBRARY_PATH",
"TILELANG_TEMPLATE_PATH",
"CUDA_HOME",
"ROCM_HOME",
"TILELANG_CACHE_DIR",
"enable_cache",
"disable_cache",
"is_cache_enabled",
"_initialize_torch_cuda_arch_flags",
]
...@@ -4,9 +4,9 @@ from tvm.target import Target ...@@ -4,9 +4,9 @@ from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
import tilelang import tilelang
from tilelang import tvm as tvm from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.env import TILELANG_PRINT_ON_COMPILATION
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
NVRTCKernelAdapter, TorchDLPackKernelAdapter) NVRTCKernelAdapter, TorchDLPackKernelAdapter)
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
...@@ -114,7 +114,7 @@ class JITKernel(object): ...@@ -114,7 +114,7 @@ class JITKernel(object):
# Print log on compilation starts # Print log on compilation starts
# NOTE(Chenggang): printing could let the training/inference framework easier to know # NOTE(Chenggang): printing could let the training/inference framework easier to know
# whether the communication timeout is from compilation # whether the communication timeout is from compilation
if TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"): if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
print(f"TileLang begins to compile kernel `{func.__name__}` with `{out_idx=}`") print(f"TileLang begins to compile kernel `{func.__name__}` with `{out_idx=}`")
# Compile the TileLang function and create a kernel adapter for execution. # Compile the TileLang function and create a kernel adapter for execution.
......
...@@ -2,12 +2,12 @@ import os ...@@ -2,12 +2,12 @@ import os
import torch import torch
import warnings import warnings
from torch.utils.cpp_extension import load, _import_module_from_library from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR from tilelang import env
# Define paths # Define paths
compress_util = os.path.join(TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu") compress_util = os.path.join(env.TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
# Cache directory for compiled extensions # Cache directory for compiled extensions
_CACHE_DIR = os.path.join(TILELANG_CACHE_DIR, "sparse_compressor") _CACHE_DIR = os.path.join(env.TILELANG_CACHE_DIR, "sparse_compressor")
os.makedirs(_CACHE_DIR, exist_ok=True) os.makedirs(_CACHE_DIR, exist_ok=True)
...@@ -22,9 +22,8 @@ def _get_cached_lib(): ...@@ -22,9 +22,8 @@ def _get_cached_lib():
# If loading fails, recompile # If loading fails, recompile
pass pass
from tilelang.env import _initialize_torch_cuda_arch_flags
# Set TORCH_CUDA_ARCH_LIST # Set TORCH_CUDA_ARCH_LIST
_initialize_torch_cuda_arch_flags() env._initialize_torch_cuda_arch_flags()
# Compile if not cached or loading failed # Compile if not cached or loading failed
return load( return load(
...@@ -34,8 +33,8 @@ def _get_cached_lib(): ...@@ -34,8 +33,8 @@ def _get_cached_lib():
'-O2', '-O2',
'-std=c++17', '-std=c++17',
'-lineinfo', '-lineinfo',
f'-I{CUTLASS_INCLUDE_DIR}', f'-I{env.CUTLASS_INCLUDE_DIR}',
f'-I{CUTLASS_INCLUDE_DIR}/../tools/util/include', f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
'-arch=sm_90', '-arch=sm_90',
], ],
build_directory=_CACHE_DIR, build_directory=_CACHE_DIR,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment