# TileLang package __init__.py: version detection, logging setup, native
# library loading, and public API re-exports.
import sys
import os
import ctypes

import logging
6
import warnings
7
from pathlib import Path
8
from tqdm.auto import tqdm
9

10

11
12
13
14
15
16
17
18
19
def _compute_version(dist_name="tilelang", repo_root=None) -> str:
    """Return the package version without being polluted by unrelated installs.

    Preference order:
    1) If running from a source checkout (VERSION file present at repo root),
       use the dynamic version from version_provider (falls back to plain VERSION).
    2) Otherwise, use importlib.metadata for the installed distribution.
    3) As a last resort, return a dev sentinel.

    Args:
        dist_name: Distribution name looked up via ``importlib.metadata``.
            Defaults to ``"tilelang"``.
        repo_root: Directory expected to hold the ``VERSION`` file. Defaults to
            the parent of this package's directory (the source-checkout root).

    Returns:
        The resolved version string, or ``"0.0.dev0"`` (with a RuntimeWarning)
        when neither a checkout VERSION nor installed metadata is available.
    """
    try:
        if repo_root is None:
            root = Path(__file__).resolve().parent.parent
        else:
            root = Path(repo_root)
        version_file = root / "VERSION"
        if version_file.is_file():
            try:
                from version_provider import dynamic_metadata  # type: ignore
                return dynamic_metadata("version")
            except Exception:
                # Fall back to the raw VERSION file if provider isn't available.
                raw = version_file.read_text(encoding="utf-8").strip()
                # An empty VERSION file is not a usable version; fall through
                # to the installed metadata instead of returning "".
                if raw:
                    return raw
    except Exception:
        # Deliberate best effort: any failure inspecting the checkout (path
        # resolution, unreadable/undecodable VERSION) falls through to the
        # installed distribution metadata below.
        pass

    try:
        from importlib.metadata import version as _dist_version  # py3.8+
        return _dist_version(dist_name)
    except Exception as exc:
        warnings.warn(
            f"tilelang version metadata unavailable ({exc!r}); using development version.",
            RuntimeWarning,
            stacklevel=2,
        )
        return "0.0.dev0"


__version__ = _compute_version()

class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints formatted records via ``tqdm.write``.

    Routing log output through tqdm lets it appear alongside active progress
    bars without the two streams interleaving and corrupting each other.
    """

    def __init__(self, level=logging.NOTSET):
        """Create the handler; ``level`` is forwarded to ``logging.Handler``."""
        super().__init__(level=level)

    def emit(self, record):
        """Format ``record`` and hand the resulting text to ``tqdm.write``."""
        try:
            tqdm.write(self.format(record))
        except Exception:
            # Standard logging contract: report handler failures via
            # handleError instead of raising into the caller.
            self.handleError(record)


def set_log_level(level):
    """Set the logging level for the module's logger.

    Args:
        level (str or int): Either a level name (case-insensitive, surrounding
            whitespace ignored), e.g. ``'INFO'``, or a numeric level such as
            ``logging.INFO``.
        OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'

    Unrecognized level names fall back to ``logging.INFO``.
    """
    if isinstance(level, str):
        resolved = getattr(logging, level.strip().upper(), None)
        # getattr on the logging module can also hit non-level attributes
        # (e.g. "basic_format" -> logging.BASIC_FORMAT, a str), which setLevel
        # would reject; only numeric levels are accepted, else default to INFO.
        level = resolved if isinstance(resolved, int) else logging.INFO
    logging.getLogger(__name__).setLevel(level)


def _init_logger():
    """Initialize this module's logger with a tqdm-based handler at INFO level.

    Attaches a ``TqdmLoggingHandler`` with a timestamped formatter, disables
    propagation to ancestor loggers, and sets the level to INFO. Calling it
    again (e.g. after ``importlib.reload``) does not attach a second handler,
    which would otherwise print every log line twice.
    """
    logger = logging.getLogger(__name__)
    # Logger objects are process-global and survive a module reload, while a
    # reload creates a brand-new TqdmLoggingHandler class; compare by class
    # name so a handler attached by a previous load is still recognized.
    already_attached = any(
        type(existing).__name__ == TqdmLoggingHandler.__name__
        for existing in logger.handlers
    )
    if not already_attached:
        handler = TqdmLoggingHandler()
        formatter = logging.Formatter(
            fmt="%(asctime)s  [TileLang:%(name)s:%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    # Do not forward records to ancestor (e.g. root) handlers as well.
    logger.propagate = False
    set_log_level("INFO")


# Configure the package logger once, at import time.
_init_logger()

# Module-level logger for this package module.
logger = logging.getLogger(__name__)

from .env import enable_cache, disable_cache, is_cache_enabled  # noqa: F401
from .env import env as env  # noqa: F401

import tvm
import tvm.base  # noqa: F401
from tvm import DataType  # noqa: F401

# libinfo provides get_dll_directories()/find_lib_path(), used by
# _load_tile_lang_lib below to locate the native library.
# NOTE(review): a previous comment here said "Setup tvm search path before
# importing tvm", but tvm is already imported above — confirm whether any
# search-path setup must happen before `import tvm` (e.g. inside .env).
from . import libinfo


def _load_tile_lang_lib():
    """Load the TileLang native library with ``ctypes``.

    On Windows with Python 3.8+, every directory reported by
    ``libinfo.get_dll_directories()`` is first registered through
    ``os.add_dll_directory`` so the loader can find dependent DLLs there.

    Returns:
        tuple: ``(ctypes.CDLL handle, lib_path)``, where ``lib_path`` is the
        value returned by ``libinfo.find_lib_path``.
    """
    if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
        for path in libinfo.get_dll_directories():
            os.add_dll_directory(path)
    # Pick the "tilelang" library when TVM reports runtime-only mode, and the
    # full "tilelang_module" library otherwise.
    # pylint: disable=protected-access
    lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module"
    # pylint: enable=protected-access
    lib_path = libinfo.find_lib_path(lib_name)
    return ctypes.CDLL(lib_path), lib_path


# Load the native library once, at package import time, only when
# env.SKIP_LOADING_TILELANG_SO == "0".
# NOTE(review): when loading is skipped, _LIB and _LIB_PATH are left
# undefined; any code referencing them must handle that case.
if env.SKIP_LOADING_TILELANG_SO == "0":
    _LIB, _LIB_PATH = _load_tile_lang_lib()
# Public API re-exports.
# Note: `compile` below shadows the builtin `compile` within this module's
# namespace.
from .jit import jit, JITKernel, compile  # noqa: F401
from .profiler import Profiler  # noqa: F401
from .cache import clear_cache  # noqa: F401

from .utils import (
    TensorSupplyType,  # noqa: F401
    deprecated,  # noqa: F401
)
from .layout import (
    Layout,  # noqa: F401
    Fragment,  # noqa: F401
)
from . import (
    analysis,  # noqa: F401
    transform,  # noqa: F401
    language,  # noqa: F401
    engine,  # noqa: F401
)
from .autotuner import autotune  # noqa: F401
from .transform import PassConfigKey  # noqa: F401

from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401

from .math import *  # noqa: F403

from . import ir  # noqa: F401

from . import tileop  # noqa: F401