env.py
import sys
import os
import pathlib
import logging

import shutil
import glob

from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)

# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
                             ", which may lead to compilation bugs when using the TileLang backend.")
COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE = (
    "Composable Kernel is not installed or found in the expected path"
    ", which may lead to compilation bugs when using the TileLang backend.")
TL_TEMPLATE_NOT_FOUND_MESSAGE = (
    "TileLang templates are not installed or found in the expected path"
    ", which may lead to compilation bugs when using the TileLang backend.")
TVM_LIBRARY_NOT_FOUND_MESSAGE = "TVM is not installed or found in the expected path."


def _find_cuda_home() -> str:
    """Find the CUDA install path.

    Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
    """
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            # Standard CUDA pattern
            if "cuda" in nvcc_path.lower():
                cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
            # NVIDIA HPC SDK pattern
            elif "hpc_sdk" in nvcc_path.lower():
                # Navigate to the root directory of nvhpc
                cuda_home = os.path.dirname(os.path.dirname(os.path.dirname(nvcc_path)))
            # Generic fallback for non-standard or symlinked installs
            else:
                cuda_home = os.path.dirname(os.path.dirname(nvcc_path))

        else:
            # Guess #3
            if sys.platform == 'win32':
                cuda_homes = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                cuda_home = '' if len(cuda_homes) == 0 else cuda_homes[0]
            else:
                # Linux/macOS
                if os.path.exists('/usr/local/cuda'):
                    cuda_home = '/usr/local/cuda'
                elif os.path.exists('/opt/nvidia/hpc_sdk/Linux_x86_64'):
                    cuda_home = '/opt/nvidia/hpc_sdk/Linux_x86_64'

            # Validate found path
            if cuda_home is None or not os.path.exists(cuda_home):
                cuda_home = None

    return cuda_home if cuda_home is not None else ""
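
# Example (illustrative; the path below is hypothetical):
#   >>> os.environ["CUDA_HOME"] = "/usr/local/cuda-12.4"
#   >>> _find_cuda_home()
#   '/usr/local/cuda-12.4'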


def _find_rocm_home() -> str:
    """Find the ROCM install path."""
    rocm_home = os.environ.get('ROCM_PATH') or os.environ.get('ROCM_HOME')
    if rocm_home is None:
        rocmcc_path = shutil.which("hipcc")
        if rocmcc_path is not None:
            rocm_home = os.path.dirname(os.path.dirname(rocmcc_path))
        else:
            rocm_home = '/opt/rocm'
            if not os.path.exists(rocm_home):
                rocm_home = None
    return rocm_home if rocm_home is not None else ""


# Cache control
class CacheState:
    """Class to manage global kernel caching state."""
    _enabled = True

    @classmethod
    def enable(cls):
        """Enable kernel caching globally."""
        cls._enabled = True

    @classmethod
    def disable(cls):
        """Disable kernel caching globally."""
        cls._enabled = False

    @classmethod
    def is_enabled(cls) -> bool:
        """Return current cache state."""
        return cls._enabled


@dataclass
class EnvVar:
    """
    Descriptor for managing access to a single environment variable.

    Purpose
    -------
    In many projects, access to environment variables is scattered across the codebase:
        * `os.environ.get(...)` calls are repeated everywhere
        * Default values are hard-coded in multiple places
        * Overriding env vars for tests/debugging is messy
        * There's no central place to see all environment variables a package uses

    This descriptor solves those issues by:
        1. Centralizing the definition of the variable's **key** and **default value**
        2. Allowing *dynamic* reads from `os.environ` so changes take effect immediately
        3. Supporting **forced overrides** at runtime (for unit tests or debugging)
        4. Funneling every read through a single `get()` method, so overrides are easy to audit
        5. Optionally syncing forced values back to `os.environ` if global consistency is desired

    How it works
    ------------
    - This is a `dataclass` implementing the descriptor protocol (`__get__`, `__set__`)
    - When used as a class attribute, `instance.attr` triggers `__get__()`
        → returns either the forced override or the live value from `os.environ`
    - Assigning to the attribute (`instance.attr = value`) triggers `__set__()`
        → stores `_forced_value` for future reads
    - You may uncomment the `os.environ[...] = value` line in `__set__` if you want
      the override to persist globally in the process

    Example
    -------
    ```python
    class Environment:
        TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION", "0")

    env = Environment()
    print(env.TILELANG_PRINT_ON_COMPILATION)  # Reads from os.environ (with default fallback)
    env.TILELANG_PRINT_ON_COMPILATION = "1"   # Forces value to "1" until changed/reset
    ```

    Benefits
    --------
    * Centralizes all env-var keys and defaults in one place
    * Live, up-to-date reads (no stale values after `import`)
    * Testing convenience (override without touching the real env)
    * Improves IDE discoverability and type hints
    * Avoids hardcoding `os.environ.get(...)` in multiple places
    """

    key: str  # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
    default: str  # Default value if the environment variable is not set
    _forced_value: Optional[str] = None  # Temporary runtime override (mainly for tests/debugging)

    def get(self):
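        """Return the forced override if set, else the live value from os.environ (or the default)."""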
        if self._forced_value is not None:
            return self._forced_value
        return os.environ.get(self.key, self.default)

    def __get__(self, instance, owner):
        """
        Called when the attribute is accessed.
        1. If a forced value is set, return it
        2. Otherwise, look up the value in os.environ; return the default if missing
        """
        return self.get()

    def __set__(self, instance, value):
        """
        Called when the attribute is assigned to.
        Stores the value as a runtime override (forced value).
        Optionally, you can also sync this into os.environ for global effect.
        """
        self._forced_value = value
        # Uncomment the following line if you want the override to persist globally:
        # os.environ[self.key] = value
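
# Minimal usage sketch for EnvVar (illustrative; `DemoEnv` is hypothetical, not part of this module):
#
#   class DemoEnv:
#       TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", "~/.tilelang/cache")
#
#   demo = DemoEnv()
#   demo.TILELANG_CACHE_DIR                # live read from os.environ, falling back to the default
#   demo.TILELANG_CACHE_DIR = "/tmp/tl"    # forced override, stored on the class-level descriptor
#   DemoEnv.__dict__["TILELANG_CACHE_DIR"]._forced_value = None   # reset; plain class access would trigger __get__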


# Cache control API (wrap CacheState)
enable_cache = CacheState.enable
disable_cache = CacheState.disable
is_cache_enabled = CacheState.is_enabled
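
# Usage sketch (illustrative; assumes this module is importable as `tilelang.env`):
#
#   from tilelang.env import enable_cache, disable_cache, is_cache_enabled
#   disable_cache()                # e.g. in tests, to force fresh compilation
#   assert not is_cache_enabled()
#   enable_cache()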


# Central environment configuration for TileLang (built on EnvVar and CacheState defined above)
class Environment:
    """
    Environment configuration for TileLang.
    Handles CUDA/ROCm detection, integration paths, template/cache locations,
    auto-tuning configs, and build options.
    """

    # CUDA/ROCm home directories
    CUDA_HOME = _find_cuda_home()
    ROCM_HOME = _find_rocm_home()

    # Path to the TileLang package root
    TILELANG_PACKAGE_PATH = pathlib.Path(__file__).resolve().parent

    # External library include paths
    CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None)
    COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None)

    # TVM integration
    TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
    TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None)

    # TileLang resources
    TILELANG_TEMPLATE_PATH = EnvVar("TL_TEMPLATE_PATH", None)
    TILELANG_CACHE_DIR = EnvVar("TILELANG_CACHE_DIR", os.path.expanduser("~/.tilelang/cache"))
    TILELANG_TMP_DIR = EnvVar("TILELANG_TMP_DIR", os.path.join(TILELANG_CACHE_DIR.get(), "tmp"))

    # Kernel Build options
    TILELANG_PRINT_ON_COMPILATION = EnvVar("TILELANG_PRINT_ON_COMPILATION",
                                           "1")  # print kernel name on compile
    TILELANG_CLEAR_CACHE = EnvVar("TILELANG_CLEAR_CACHE", "0")  # clear cache automatically if set

    # Auto-tuning settings
    TILELANG_AUTO_TUNING_CPU_UTILITIES = EnvVar("TILELANG_AUTO_TUNING_CPU_UTILITIES",
                                                "0.9")  # percent of CPUs used
    TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS",
                                             "-1")  # -1 means auto
    TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT",
                                                "-1")  # -1 means no limit

    # TVM import and shared-library loading
    SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
    TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)

    def _initialize_torch_cuda_arch_flags(self) -> None:
        """
        Detect target CUDA architecture and set TORCH_CUDA_ARCH_LIST
        to ensure PyTorch extensions are built for the proper GPU arch.
        """
        from tilelang.contrib import nvcc
        from tilelang.utils.target import determine_target

        target = determine_target(return_object=True)  # get target GPU
        compute_version = nvcc.get_target_compute_version(target)  # e.g. "8.6"
        major, minor = nvcc.parse_compute_version(compute_version)  # split to (8, 6)
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"  # set env var for PyTorch
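
    # Illustrative effect (hypothetical values; the chosen arch depends on the detected GPU):
    #   env._initialize_torch_cuda_arch_flags()    # `env` is the module-level instance created below
    #   os.environ["TORCH_CUDA_ARCH_LIST"]         # -> e.g. "8.6" for an SM 8.6 device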

    # Cache control API (wrap CacheState)
    def is_cache_enabled(self) -> bool:
        return CacheState.is_enabled()

    def enable_cache(self) -> None:
        CacheState.enable()

    def disable_cache(self) -> None:
        CacheState.disable()

    def is_print_on_compilation_enabled(self) -> bool:
        return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")


# Instantiate as a global configuration object
env = Environment()
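
# Typical reads from the global configuration object (illustrative):
#   env.TILELANG_CACHE_DIR                   # honors $TILELANG_CACHE_DIR, else ~/.tilelang/cache
#   env.is_print_on_compilation_enabled()    # True for "1", "true", "yes", or "on" (case-insensitive)
#   env.is_cache_enabled()                   # delegates to CacheState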

# Export CUDA_HOME and ROCM_HOME; both are static after initialization.
CUDA_HOME = env.CUDA_HOME
ROCM_HOME = env.ROCM_HOME

# Initialize TVM paths
if env.TVM_IMPORT_PYTHON_PATH is not None:
    os.environ["PYTHONPATH"] = env.TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
    sys.path.insert(0, env.TVM_IMPORT_PYTHON_PATH)
else:
    install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
    if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
        os.environ["PYTHONPATH"] = (
            install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
        sys.path.insert(0, install_tvm_path + "/python")
        env.TVM_IMPORT_PYTHON_PATH = install_tvm_path + "/python"

    develop_tvm_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
    if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
        os.environ["PYTHONPATH"] = (
            develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
        sys.path.insert(0, develop_tvm_path + "/python")
        env.TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"

    develop_tvm_library_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
    install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
    if os.environ.get("TVM_LIBRARY_PATH") is None:
        if os.path.exists(develop_tvm_library_path):
            os.environ["TVM_LIBRARY_PATH"] = develop_tvm_library_path
        elif os.path.exists(install_tvm_library_path):
            os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
        else:
            logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
        # pip install build library path
        lib_path = os.path.join(env.TILELANG_PACKAGE_PATH, "lib")
        existing_path = os.environ.get("TVM_LIBRARY_PATH")
        if existing_path:
            os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}"
        else:
            os.environ["TVM_LIBRARY_PATH"] = lib_path
        env.TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)

# Initialize CUTLASS paths
if os.environ.get("TL_CUTLASS_PATH", None) is None:
    install_cutlass_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
    develop_cutlass_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
    if os.path.exists(install_cutlass_path):
        os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
        env.CUTLASS_INCLUDE_DIR = install_cutlass_path + "/include"
    elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
        os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
        env.CUTLASS_INCLUDE_DIR = develop_cutlass_path + "/include"
    else:
        logger.warning(CUTLASS_NOT_FOUND_MESSAGE)

# Initialize COMPOSABLE_KERNEL paths
if os.environ.get("TL_COMPOSABLE_KERNEL_PATH", None) is None:
    install_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "3rdparty", "composable_kernel")
    develop_ck_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "composable_kernel")
    if os.path.exists(install_ck_path):
        os.environ["TL_COMPOSABLE_KERNEL_PATH"] = install_ck_path + "/include"
        env.COMPOSABLE_KERNEL_INCLUDE_DIR = install_ck_path + "/include"
    elif (os.path.exists(develop_ck_path) and develop_ck_path not in sys.path):
        os.environ["TL_COMPOSABLE_KERNEL_PATH"] = develop_ck_path + "/include"
        env.COMPOSABLE_KERNEL_INCLUDE_DIR = develop_ck_path + "/include"
    else:
        logger.warning(COMPOSABLE_KERNEL_NOT_FOUND_MESSAGE)

# Initialize TL_TEMPLATE_PATH
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
    install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
    develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
    if os.path.exists(install_tl_template_path):
        os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
        env.TILELANG_TEMPLATE_PATH = install_tl_template_path
    elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
        os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
        env.TILELANG_TEMPLATE_PATH = develop_tl_template_path
    else:
        logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)

# Export static variables after initialization.
CUTLASS_INCLUDE_DIR = env.CUTLASS_INCLUDE_DIR
COMPOSABLE_KERNEL_INCLUDE_DIR = env.COMPOSABLE_KERNEL_INCLUDE_DIR
TILELANG_TEMPLATE_PATH = env.TILELANG_TEMPLATE_PATH
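
# Downstream usage sketch (illustrative; assumes this module is importable as `tilelang.env`):
#
#   import os
#   from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR
#   nvcc_bin = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else "nvcc"
#   cutlass_flags = [f"-I{CUTLASS_INCLUDE_DIR}"] if CUTLASS_INCLUDE_DIR else []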