# __init__.py (3.04 KB)
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
import os
import ctypes

import logging
from tqdm import tqdm


class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes records through ``tqdm.write``.

    Routing log output through tqdm instead of writing to the stream directly
    keeps log lines from corrupting any tqdm progress bars being drawn.
    """

    def __init__(self, level=logging.NOTSET):
        """Create the handler.

        Args:
            level (int or str): Minimum level this handler emits. Defaults to
                ``logging.NOTSET``, i.e. every record the logger passes on.
        """
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and write it with ``tqdm.write``.

        ``RecursionError`` is re-raised, matching the standard library's
        ``logging.StreamHandler.emit``, so running out of stack is not silently
        swallowed. Any other failure is reported via ``handleError``.
        """
        try:
            msg = self.format(record)
            tqdm.write(msg)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)


def set_log_level(level):
    """Change the level of this package's logger.

    Args:
        level (str or int): Either a level name such as ``'INFO'`` (looked up,
            upper-cased, as an attribute of the ``logging`` module) or a numeric
            level such as ``logging.INFO``, which is passed through unchanged.
            OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
            A name with no matching ``logging`` attribute falls back to INFO.
    """
    resolved = (
        getattr(logging, level.upper(), logging.INFO)
        if isinstance(level, str)
        else level
    )
    logging.getLogger(__name__).setLevel(resolved)


def _init_logger():
    """Attach a tqdm-aware handler to this package's logger.

    Configures ``logging.getLogger(__name__)`` with a ``TqdmLoggingHandler``
    that uses a timestamped ``[TileLang:<name>:<level>]`` format. It disables
    propagation to ancestor loggers and sets the default level to INFO.
    """
    logger = logging.getLogger(__name__)
    handler = TqdmLoggingHandler()
    formatter = logging.Formatter(
        fmt="%(asctime)s  [TileLang:%(name)s:%(levelname)s]: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # Records are fully handled by the handler above; stop them from also
    # reaching ancestor (e.g. root) handlers.
    logger.propagate = False
    set_log_level("INFO")


# Package-level logger; getLogger(__name__) returns the same object that
# _init_logger() configures below.
logger = logging.getLogger(__name__)
_init_logger()

56
from .env import enable_cache, disable_cache, is_cache_enabled  # noqa: F401
57
from .env import env as env  # noqa: F401
58
59

import tvm
60
import tvm.base
61
from tvm import DataType  # noqa: F401
62
63
64
65
66
67
68
69
70
71

from . import libinfo


def _load_tile_lang_lib():
    """Load the TileLang native shared library with ``ctypes``.

    Returns:
        tuple: ``(lib, path)``, where ``lib`` is the ``ctypes.CDLL`` handle and
        ``path`` is the first entry returned by ``libinfo.find_lib_path``.
    """
    if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
        # Since Python 3.8, Windows no longer resolves a DLL's dependencies via
        # PATH; their directories must be registered explicitly.
        for path in libinfo.get_dll_directories():
            os.add_dll_directory(path)
    # pylint: disable=protected-access
    # NOTE(review): presumably a runtime-only TVM build pairs with "tilelang"
    # and a full build with "tilelang_module" — confirm.
    lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module"
    # pylint: enable=protected-access
    # optional=False presumably makes find_lib_path raise when no library is
    # found (TODO confirm against libinfo).
    lib_path = libinfo.find_lib_path(lib_name, optional=False)
    first_path = lib_path[0]
    return ctypes.CDLL(first_path), first_path


# Load the native library only once, at package import time, unless the
# ``env.SKIP_LOADING_TILELANG_SO`` setting is anything other than "0".
# NOTE(review): when skipped, _LIB/_LIB_PATH are left undefined.
if env.SKIP_LOADING_TILELANG_SO == "0":
    _LIB, _LIB_PATH = _load_tile_lang_lib()

82
from .jit import jit, JITKernel, compile  # noqa: F401
83
from .profiler import Profiler  # noqa: F401
84
from .cache import clear_cache  # noqa: F401
85

86
87
from .utils import (
    TensorSupplyType,  # noqa: F401
    deprecated,  # noqa: F401
)
from .layout import (
    Layout,  # noqa: F401
    Fragment,  # noqa: F401
)
from . import (
    transform,  # noqa: F401
    language,  # noqa: F401
    engine,  # noqa: F401
)
99
from .autotuner import autotune  # noqa: F401
100
from .transform import PassConfigKey  # noqa: F401
101

102
from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
103
104

from .version import __version__  # noqa: F401
105
106

from .math import *  # noqa: F403