"""
This module provides just-in-time (JIT) compilation utilities for TileLang (tl)
programs. It includes functionality to JIT-compile TileLang programs into a
runnable kernel adapter using TVM.
"""

7
from typing import Callable, List, Literal, Union, Any, Optional, Dict
8
9
10
11
12
13
14
15

from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target

from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
16
from tilelang.cache import cached
17
18
19
20
21
22
23
24
25
from logging import getLogger

# Logger named after this module; `jit` writes its verbose compilation messages here.
logger = getLogger(name=__name__)


def jit(
    func: Callable = None,
    *,  # Enforce keyword-only arguments from here on
    out_idx: Union[List[int], int] = None,
    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
    target: Union[str, Target] = "auto",
    verbose: bool = False,
    **pass_config_kwargs: Optional[Dict[str, Any]],
) -> Union[BaseKernelAdapter, Callable[[PrimFunc], BaseKernelAdapter]]:
    """
    A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
    into a runnable kernel adapter using TVM. If called with a function, it
    compiles that function immediately and returns the adapter. If called with
    only keyword arguments, it returns a decorator that compiles the function it
    is applied to.

    Parameters
    ----------
    func : Callable, optional
        The TileLang PrimFunc to JIT-compile. If None, this function returns a
        decorator that expects a TileLang PrimFunc.
    out_idx : Union[List[int], int], optional
        The index (or list of indices) of the function outputs. This can be used
        to specify which outputs from the compiled function will be returned.
    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
        The wrapper type to use for the kernel adapter. Supported values are
        "dlpack", "ctypes" and "cython" (the default).
    target : Union[str, Target], optional
        The compilation target for TVM. If set to "auto", an appropriate target
        will be inferred automatically. Otherwise, must be one of the supported
        strings in AVALIABLE_TARGETS or a TVM Target instance.
    verbose : bool, optional
        If True, log the TileLang function at INFO level before compiling it.
        The flag is also forwarded to ``compile``.
    **pass_config_kwargs
        Extra keyword arguments forwarded unchanged to ``compile``. Only keywords
        that ``compile`` accepts and that ``jit`` does not already consume are
        valid here, i.e. ``pass_configs`` and ``target_host``. Any other keyword
        makes ``compile`` raise ``TypeError`` when compilation runs.

    Returns
    -------
    BaseKernelAdapter or Callable[[PrimFunc], BaseKernelAdapter]
        The adapter of the compiled kernel when ``func`` is given. Otherwise, a
        decorator that performs the compilation and returns that adapter.

    Raises
    ------
    AssertionError
        If ``execution_backend`` is not a supported backend, or if the provided
        target is a string not present in AVALIABLE_TARGETS.
    """
    # Explicit raises (instead of `assert`) keep this validation active under
    # `python -O` while preserving the documented AssertionError type.
    # The backend is checked first so that a bad backend is reported without
    # doing any target-resolution work.
    if execution_backend not in ("dlpack", "ctypes", "cython"):
        raise AssertionError("Invalid execution backend.")

    # A string target must name a known target. It is resolved (e.g. "auto")
    # through determine_target before being wrapped in a TVM Target.
    if isinstance(target, str):
        if target not in AVALIABLE_TARGETS:
            raise AssertionError(f"Invalid target: {target}")
        target = determine_target(target)

    target = Target(target)

    def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
        """
        Compile the given TileLang PrimFunc with TVM and build a kernel adapter.

        Parameters
        ----------
        tilelang_func : tvm.tir.PrimFunc
            The TileLang (TVM TIR) function to compile.

        Returns
        -------
        BaseKernelAdapter
            The ``adapter`` attribute of the kernel returned by ``compile``.
        """
        if verbose:
            # Lazy %-style args: the PrimFunc is only stringified if the record
            # is actually emitted.
            logger.info("Compiling TileLang function:\n%s", tilelang_func)
        # `compile` here is this module's function (it shadows the builtin).
        return compile(
            tilelang_func,
            target=target,
            verbose=verbose,
            execution_backend=execution_backend,
            out_idx=out_idx,
            **pass_config_kwargs,
        ).adapter

    # If `func` was given (bare `@jit` or a direct call), compile it immediately.
    if func is not None:
        return _compile_and_create_adapter(func)

    # Otherwise act as a decorator factory: the returned callable compiles the
    # function it decorates and returns its adapter.
    return _compile_and_create_adapter
108
109
110
111
112


def compile(
    func: Optional[PrimFunc] = None,
    out_idx: Union[List[int], int] = None,
    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
    target: Union[str, Target] = "auto",
    target_host: Optional[Union[str, Target]] = None,
    verbose: bool = False,
    pass_configs: Optional[Dict[str, Any]] = None,
) -> JITKernel:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.

    Every argument is forwarded unchanged, by keyword, to
    ``tilelang.cache.cached``, which does the actual work. Judging by its name,
    ``cached`` presumably reuses a previously built kernel for identical inputs;
    confirm against ``tilelang.cache``.

    Note: this function shadows the builtin ``compile`` within this module.

    Parameters
    ----------
    func : PrimFunc, optional
        The TileLang (TVM TIR) function to compile.
    out_idx : Union[List[int], int], optional
        The index (or list of indices) of the function outputs to return.
    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
        The wrapper type for the kernel adapter. Defaults to "cython".
    target : Union[str, Target], optional
        The compilation target. Defaults to "auto". Unlike ``jit``, this
        function does no validation or resolution; the value is passed through
        as-is.
    target_host : Union[str, Target], optional
        The host target, passed through unchanged.
    verbose : bool, optional
        The verbosity flag, passed through unchanged.
    pass_configs : Dict[str, Any], optional
        Pass configuration options, passed through unchanged.

    Returns
    -------
    JITKernel
        The kernel object returned by ``cached``.
    """
    return cached(
        func=func,
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
    )