__init__.py 3.59 KB
Newer Older
1
2
3
4
5
import sys
import os
import ctypes

import logging
6
import warnings
7
8
from tqdm import tqdm

9
from importlib.metadata import PackageNotFoundError, version
10

11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Resolve ``__version__``: prefer installed distribution metadata; if the
# distribution is not installed, ask the ``version_provider`` helper; if that
# also fails for any reason, use a placeholder development version and warn.
try:
    __version__ = version('tilelang')
except PackageNotFoundError:
    try:
        from version_provider import dynamic_metadata

        __version__ = dynamic_metadata('version')
    except Exception as exc:  # best effort: never fail the import over a version string
        warnings.warn(
            f"tilelang version metadata unavailable ({exc!r}); using development version.",
            category=RuntimeWarning,
            stacklevel=2,
        )
        __version__ = "0.0.dev0"

class TqdmLoggingHandler(logging.Handler):
    """Logging handler that routes formatted records through ``tqdm.write``.

    ``tqdm.write`` prints a message without overlapping active tqdm progress
    bars, so log output does not corrupt bar rendering.
    """

    def __init__(self, level=logging.NOTSET):
        """Create the handler.

        Args:
            level: Minimum level this handler emits (default ``logging.NOTSET``).
        """
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and write it via ``tqdm.write``.

        Errors raised while formatting or writing are reported through
        ``Handler.handleError`` (standard logging behavior) instead of
        propagating, except ``RecursionError``, which is re-raised.
        """
        try:
            msg = self.format(record)
            tqdm.write(msg)
        except RecursionError:
            # Mirror logging.StreamHandler.emit (bpo-36272): swallowing a
            # RecursionError here can mask runaway recursion.
            raise
        except Exception:
            self.handleError(record)


def set_log_level(level):
    """Set the logging level for the module's logger.

    Args:
        level (str or int): Either a level name (case-insensitive, surrounding
            whitespace ignored, e.g. 'INFO') or a numeric level (e.g. logging.INFO).
        OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'

    An unrecognized level name falls back to ``logging.INFO`` and emits a
    ``RuntimeWarning`` rather than being ignored silently.
    """
    if isinstance(level, str):
        resolved = getattr(logging, level.strip().upper(), None)
        # ``getattr`` can also hit non-level module attributes (for example
        # ``logging.BASIC_FORMAT``, a str) that ``setLevel`` would reject, so
        # only integer levels are accepted.
        if not isinstance(resolved, int):
            warnings.warn(
                f"Unknown log level {level!r}; falling back to 'INFO'.",
                RuntimeWarning,
                stacklevel=2,
            )
            resolved = logging.INFO
        level = resolved
    logger = logging.getLogger(__name__)
    logger.setLevel(level)


def _init_logger():
    """Configure the module's logger with a tqdm-aware handler and INFO level.

    The logger is set not to propagate, so its records are not also handled
    by the root logger's handlers. The handler is attached only once, so
    running this again (e.g. after ``importlib.reload``) does not duplicate
    every log line.
    """
    logger = logging.getLogger(__name__)
    already_installed = any(
        getattr(existing, "_tilelang_tqdm_handler", False) for existing in logger.handlers
    )
    if not already_installed:
        handler = TqdmLoggingHandler()
        # Marker attribute rather than an isinstance check, so the guard still
        # recognizes the handler after a reload redefines TqdmLoggingHandler.
        handler._tilelang_tqdm_handler = True
        formatter = logging.Formatter(
            fmt="%(asctime)s  [TileLang:%(name)s:%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.propagate = False
    set_log_level("INFO")


# Module-wide logger handle. ``logging.getLogger`` returns the same object that
# ``_init_logger`` configures, so binding it first is equivalent.
logger = logging.getLogger(__name__)
_init_logger()

74
from .env import enable_cache, disable_cache, is_cache_enabled  # noqa: F401
75
from .env import env as env  # noqa: F401
76
77

import tvm
78
import tvm.base  # noqa: F401
79
from tvm import DataType  # noqa: F401
80

81
# Setup tvm search path before importing tvm
82
83
84
85
86
87
88
89
90
from . import libinfo


def _load_tile_lang_lib():
    """Load the TileLang shared library with ``ctypes``.

    On Windows with Python 3.8+, each directory from
    ``libinfo.get_dll_directories()`` is first registered via
    ``os.add_dll_directory``. On those versions, ``PATH`` is no longer searched
    for dependent DLLs. The library name is ``"tilelang"`` when
    ``tvm.base._RUNTIME_ONLY`` is truthy, otherwise ``"tilelang_module"``.

    Returns:
        tuple: ``(ctypes.CDLL handle, lib_path)`` where ``lib_path`` is the
        value returned by ``libinfo.find_lib_path``.
    """
    if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
        for path in libinfo.get_dll_directories():
            os.add_dll_directory(path)
    # pylint: disable=protected-access
    lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module"
    # pylint: enable=protected-access
    lib_path = libinfo.find_lib_path(lib_name)
    return ctypes.CDLL(lib_path), lib_path


# Load the native library once, at import time, unless
# ``env.SKIP_LOADING_TILELANG_SO`` is set to something other than "0".
# When skipped, ``_LIB`` and ``_LIB_PATH`` are left undefined.
if env.SKIP_LOADING_TILELANG_SO == "0":
    _LIB, _LIB_PATH = _load_tile_lang_lib()

101
from .jit import jit, JITKernel, compile  # noqa: F401
102
from .profiler import Profiler  # noqa: F401
103
from .cache import clear_cache  # noqa: F401
104

105
106
from .utils import (
    TensorSupplyType,  # noqa: F401
107
    deprecated,  # noqa: F401
108
109
110
111
112
113
114
115
116
117
)
from .layout import (
    Layout,  # noqa: F401
    Fragment,  # noqa: F401
)
from . import (
    transform,  # noqa: F401
    language,  # noqa: F401
    engine,  # noqa: F401
)
118
from .autotuner import autotune  # noqa: F401
119
from .transform import PassConfigKey  # noqa: F401
120

121
from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
122

123
from .math import *  # noqa: F403
124
125

from . import ir  # noqa: F401
126
127

from . import tileop  # noqa: F401