__init__.py 13.1 KB
Newer Older
1
2
3
4
5
6
"""
This module provides the just-in-time (JIT) compilation entry points for
TileLang (tl) programs, including the hooks used by the auto-tuner.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""

7
8
9
10
11
12
13
14
15
16
17
from typing import (
    Any,
    List,
    Union,
    Callable,
    Tuple,
    overload,
    Literal,
    Dict,  # For type hinting dicts
    Optional,
)
18
19
20
21
22
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target

from tilelang.jit.kernel import JITKernel
23
from tilelang.cache import cached
24
from os import path, makedirs
25
from logging import getLogger
26
import functools
27
from tilelang.jit.param import Kernel, _P, _RProg
28
29
30
31

# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)


32
33
def compile(
    func: PrimFunc = None,
    out_idx: Union[List[int], int, None] = None,
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
    target: Union[str, Target] = "auto",
    target_host: Union[str, Target] = None,
    verbose: bool = False,
    pass_configs: Optional[Dict[str, Any]] = None,
    compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.

    The function validates ``func``, normalizes ``compile_flags`` and then
    delegates the actual work to ``tilelang.cache.cached``.

    Parameters
    ----------
    func : tvm.tir.PrimFunc, optional
        The TileLang TIR function to compile and wrap. Although the parameter
        defaults to None, anything that is not a PrimFunc is rejected.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Execution backend to use for kernel execution (default: "cython").
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional
        Target host for cross-compilation (default: None).
    verbose : bool, optional
        Whether to enable verbose output (default: False).
    pass_configs : dict, optional
        Additional keyword arguments to pass to the Compiler PassContext.
        Available options:
            "tir.disable_vectorize": bool, default: False
            "tl.disable_tma_lower": bool, default: False
            "tl.disable_warp_specialized": bool, default: False
            "tl.config_index_bitwidth": int, default: None
            "tl.disable_dynamic_tail_split": bool, default: False
            "tl.dynamic_vectorize_size_bits": int, default: 128
            "tl.disable_safe_memory_legalize": bool, default: False
    compile_flags : Union[List[str], str], optional
        Extra compile flags forwarded to ``cached``. A single string is
        wrapped into a one-element list before forwarding (default: None).

    Returns
    -------
    JITKernel
        Whatever ``cached`` returns for the given function and options.

    Raises
    ------
    AssertionError
        If ``func`` is not a ``PrimFunc``.
    """
    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"

    # A lone flag string is shorthand for a single-element flag list.
    flag_list = [compile_flags] if isinstance(compile_flags, str) else compile_flags

    build_options = dict(
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
        compile_flags=flag_list,
    )
    return cached(func=func, **build_options)


84
class _JitImplementation:
    """
    Decorator object that backs ``jit``.

    An instance stores one set of compile options. Calling the instance on a
    function returns a wrapper that builds the program, compiles it with
    ``compile`` and memoizes the resulting kernel in ``_kernel_cache`` so that
    repeated calls with the same arguments reuse the compiled kernel.
    """

    out_idx: Optional[Union[List[int], int]]
    target: Union[str, Target]
    target_host: Union[str, Target]
    # Kept in sync with the backends accepted by ``compile`` and ``jit``.
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"]
    verbose: bool
    pass_configs: Optional[Dict[str, Any]]
    debug_root_path: Optional[str]
    compile_flags: Optional[List[str]]

    def __init__(self,
                 out_idx: Any = None,
                 target: Union[str, Target] = "auto",
                 target_host: Union[str, Target] = None,
                 execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
                 verbose: bool = False,
                 pass_configs: Optional[Dict[str, Any]] = None,
                 debug_root_path: Optional[str] = None,
                 compile_flags: Optional[List[str]] = None):
        """
        Initializes the JIT compiler decorator.

        Parameters
        ----------
        out_idx : Any, optional
            Index(es) of the output tensors to return from the compiled kernel;
            forwarded unchanged to ``compile`` (default: None).
        target : Union[str, Target], optional
            Compilation target for TVM, as a string or a TVM Target object;
            forwarded unchanged to ``compile`` (default: "auto").
        target_host : Union[str, Target], optional
            Target host for cross-compilation (default: None).
        execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
            The backend used for kernel execution and argument passing;
            forwarded unchanged to ``compile`` (default: "cython").
        verbose : bool, optional
            Verbosity flag forwarded to ``compile`` (default: False).
        pass_configs : Optional[Dict[str, Any]], optional
            A dictionary of configurations for TVM's pass context, e.g.
            "tir.disable_vectorize" (default: None).
        debug_root_path : Optional[str], optional
            If provided, the compiled kernel's source code and the program
            script are saved to files in this directory. A relative path is
            resolved against the directory three levels above this file, or
            against the current working directory when ``__file__`` is not
            defined. If None, nothing is saved (default: None).
        compile_flags : Optional[List[str]], optional
            Extra compile flags forwarded to ``compile`` (default: None).
        """
        self.out_idx = out_idx
        self.execution_backend = execution_backend
        self.target = target
        self.target_host = target_host
        self.verbose = verbose
        self.pass_configs = pass_configs
        self.compile_flags = compile_flags

        self.debug_root_path = debug_root_path
        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
            try:
                base_path = path.dirname(path.dirname(path.dirname(__file__)))
                self.debug_root_path = path.join(base_path, self.debug_root_path)
            except NameError:
                # ``__file__`` is unavailable; fall back to the working directory.
                self.debug_root_path = path.abspath(self.debug_root_path)

        # Compiled kernels keyed by (positional args, sorted kwargs, sorted tune params).
        self._kernel_cache: Dict[tuple, Kernel] = {}

    def _compile_arguments(self) -> Dict[str, Any]:
        """Return a fresh dict of the options that are passed to ``compile``."""
        return {
            'out_idx': self.out_idx,
            'execution_backend': self.execution_backend,
            'target': self.target,
            'target_host': self.target_host,
            'verbose': self.verbose,
            'pass_configs': self.pass_configs,
            'compile_flags': self.compile_flags,
        }

    def _dump_debug_sources(self, func: Any, program_result: Any, kernel_result: Any) -> None:
        """Write the kernel source and the program script into ``debug_root_path``."""
        func_name = getattr(func, '__name__', 'jit_kernel')  # Use func for name
        kernel_file = f'tilelang_jit_kernel_{func_name}.c'
        program_file = f'tilelang_jit_program_{func_name}.py'

        makedirs(self.debug_root_path, exist_ok=True)
        # File names depend only on the function name, so a later compilation
        # of the same function overwrites the earlier dump.
        with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
            print(kernel_result.get_kernel_source(), file=f)
        with open(path.join(self.debug_root_path, program_file), 'w') as f:
            print(program_result.script(), file=f)

    # This tells the type checker what the *wrapper* function will return.
    # this is for linting, please do not remove it.
    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
        ...

    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
        ...

    # Actual implementation of __call__
    def __call__(
        self,
        func: Callable[_P, _RProg]  # func is Union[Callable[_P, _RProg], PrimFunc] in original
    ) -> Callable[_P, Any]:
        """
        Wrap ``func`` so that calling it yields a compiled, cached kernel.

        The returned wrapper understands two reserved keyword arguments that
        are removed before ``func`` is called:

        * ``__tune_params`` (dict, default ``{}``): extra keyword arguments
          passed to ``func``; they are also part of the cache key.
        * ``__return_compile_arguments`` (default False): when truthy, the
          wrapper returns the dict of compile options instead of compiling.

        Positional arguments, keyword values and tune-parameter values are
        used as a dictionary key and therefore must be hashable.

        Raises
        ------
        ValueError
            From the wrapper, if ``func`` is neither a ``PrimFunc`` nor callable.
        """

        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
            # Separate out the tuning parameters from the user's kwargs
            tune_params = kwargs.pop('__tune_params', {})

            # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
            return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
            if return_compile_arguments:
                return self._compile_arguments()

            key = (
                args,
                tuple(sorted(kwargs.items())),
                tuple(sorted(tune_params.items())),
            )

            if key not in self._kernel_cache:
                # A PrimFunc is compiled as-is; any other callable is invoked
                # to produce the program that gets compiled.
                if isinstance(func, PrimFunc):
                    program_result = func
                elif callable(func):
                    program_result = func(*args, **kwargs, **tune_params)
                else:
                    raise ValueError(f"Invalid function type: {type(func)}")

                kernel_result = compile(program_result, **self._compile_arguments())

                if self.debug_root_path:
                    self._dump_debug_sources(func, program_result, kernel_result)

                self._kernel_cache[key] = kernel_result

            return self._kernel_cache[key]

        return wrapper
230
231
232
233
234
235
236
237


def jit(  # This is the new public interface
        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
        *,  # Indicates subsequent arguments are keyword-only
        out_idx: Any = None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        debug_root_path: Optional[str] = None,
        compile_flags: Optional[Union[List[str], str]] = None):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.

    Two usages are supported:

    * ``@tilelang.jit`` applied directly to a function: ``func`` is that
      function and the wrapped function is returned.
    * ``@tilelang.jit(...)`` with keyword options: ``func`` is not a callable,
      and a configured decorator is returned to be applied to a function later.

    Parameters
    ----------
    func : Union[Callable, PrimFunc, None], optional
        The function to decorate when used as ``@tilelang.jit``; otherwise
        left as None.
    out_idx : Any, optional
        Index(es) of the output tensors to return. Defaults to None.
    target : Union[str, Target], optional
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Backend for kernel execution and argument passing. Defaults to "cython".
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory to save compiled kernel source for debugging. Defaults to None.
    compile_flags : Optional[Union[List[str], str]], optional
        Extra compile flags. A single string is wrapped into a one-element
        list. Defaults to None.

    Returns
    -------
    Callable
        Either a JIT-compiled wrapper around the input function, or a configured decorator
        instance that can then be applied to a function.

    Raises
    ------
    ValueError
        If ``func`` is a non-callable ``PrimFunc``.
    """
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]

    decorating_directly = callable(func)

    # A PrimFunc that cannot be called is rejected before any decorator is built.
    if not decorating_directly and isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")

    # Both usages share the same decorator instance; they differ only in
    # whether it is applied to ``func`` right away.
    decorator = _JitImplementation(
        out_idx=out_idx,
        target=target,
        target_host=target_host,
        execution_backend=execution_backend,
        verbose=verbose,
        pass_configs=pass_configs,
        debug_root_path=debug_root_path,
        compile_flags=compile_flags)

    if decorating_directly:
        # Used as @jit: wrap the function immediately.
        return decorator(func)
    # Used as @jit(...): hand back the configured decorator.
    return decorator