"""
This module provides the just-in-time (JIT) compilation infrastructure for
TileLang (tl) programs. It includes functionality to JIT-compile TileLang
programs into a runnable kernel using TVM.
"""

from typing import (
    Any,
    List,
    Union,
    Callable,
    Tuple,
    overload,
    Literal,
    Dict,  # For type hinting dicts
    Optional,
)
18
19
20
21
22
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target

from tilelang.jit.kernel import JITKernel
23
from tilelang.utils.target import determine_target
24
from tilelang.cache import cached
25
from os import path, makedirs
26
from logging import getLogger
27
import functools
28
from tilelang.jit.param import Kernel, _P, _RProg
# Module-level logger, named after this module.
logger = getLogger(__name__)


def compile(
    func: Optional[PrimFunc] = None,
    out_idx: Union[List[int], int, None] = None,
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
    target: Union[str, Target] = "auto",
    target_host: Union[str, Target, None] = None,
    verbose: bool = False,
    pass_configs: Optional[Dict[str, Any]] = None,
    compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.

    The actual work is delegated to ``tilelang.cache.cached``; presumably it
    reuses a previously built kernel for an identical request (not visible
    here — confirm against ``tilelang.cache``).

    Parameters
    ----------
    func : tvm.tir.PrimFunc
        The TileLang TIR function to compile and wrap. Despite the ``None``
        default, a PrimFunc is required.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Execution backend to use for kernel execution (default: "cython").
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object
        (default: "auto"). It is resolved with ``determine_target`` and then
        wrapped in a ``Target``.
    target_host : Union[str, Target], optional
        Target host for cross-compilation (default: None).
    verbose : bool, optional
        Whether to enable verbose output (default: False).
    pass_configs : dict, optional
        Additional options to pass to the Compiler PassContext.
        Available options:
            "tir.disable_vectorize": bool, default: False
            "tl.disable_tma_lower": bool, default: False
            "tl.disable_warp_specialized": bool, default: False
            "tl.config_index_bitwidth": int, default: None
            "tl.disable_dynamic_tail_split": bool, default: False
            "tl.dynamic_vectorize_size_bits": int, default: 128
            "tl.disable_safe_memory_legalize": bool, default: False
    compile_flags : Union[List[str], str], optional
        Extra flags for the underlying compiler. A single string is wrapped
        into a one-element list (default: None).

    Returns
    -------
    JITKernel
        The compiled kernel returned by ``cached``.

    Raises
    ------
    AssertionError
        If ``func`` is not a ``PrimFunc`` (the check is skipped under ``python -O``).
    """
    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"

    # Normalise a single flag string to the list form expected downstream.
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]

    # This path is not a performance critical path, so we can afford to convert the target.
    target = Target(determine_target(target))

    return cached(
        func=func,
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
        compile_flags=compile_flags,
    )


class _JitImplementation:
    """
    Configurable JIT decorator that compiles and memoises TileLang kernels.

    An instance stores the compilation options. Calling it on a function returns
    a wrapper that, per distinct call signature, invokes the function to obtain a
    PrimFunc, compiles it with ``compile`` and caches the resulting kernel.
    """

    out_idx: Optional[Union[List[int], int]]
    target: Union[str, Target]
    target_host: Optional[Union[str, Target]]
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"]
    verbose: bool
    pass_configs: Optional[Dict[str, Any]]
    debug_root_path: Optional[str]
    compile_flags: Optional[Union[List[str], str]]

    def __init__(self,
                 out_idx: Any = None,
                 target: Union[str, Target] = "auto",
                 target_host: Optional[Union[str, Target]] = None,
                 execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
                 verbose: bool = False,
                 pass_configs: Optional[Dict[str, Any]] = None,
                 debug_root_path: Optional[str] = None,
                 compile_flags: Optional[Union[List[str], str]] = None):
        """
        Initializes the JIT compiler decorator.

        Parameters
        ----------
        out_idx : Any, optional
            Index(es) of the output tensors, forwarded unchanged to ``compile``
            (default: None).
        target : Union[str, Target], optional
            Compilation target for TVM, e.g. "cuda" or "llvm", or a TVM Target
            object. "auto" lets ``compile`` determine it (default: "auto").
        target_host : Union[str, Target], optional
            Target host for cross-compilation, similar to `target` (default: None).
        execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
            Backend used for kernel execution and argument passing, forwarded
            to ``compile`` (default: "cython").
        verbose : bool, optional
            If True, enables verbose output during compilation (default: False).
        pass_configs : Optional[Dict[str, Any]], optional
            Configuration options for TVM's pass context, e.g.
            "tir.disable_vectorize" (default: None).
        debug_root_path : Optional[str], optional
            If provided, each newly compiled kernel's source and the TileLang
            program script are written to files in this directory (default: None).
            A relative path is resolved against the directory three levels above
            this module file, or against the current working directory if
            ``__file__`` is unavailable.
        compile_flags : Optional[Union[List[str], str]], optional
            Additional compilation flags to pass to the compiler.
            If None, no additional compilation flags are passed (default: None).
        """
        self.out_idx = out_idx
        self.execution_backend = execution_backend
        self.target = target
        self.target_host = target_host
        self.verbose = verbose
        self.pass_configs = pass_configs
        self.compile_flags = compile_flags

        # Corrected debug_root_path handling
        self.debug_root_path = debug_root_path
        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
            try:
                base_path = path.dirname(path.dirname(path.dirname(__file__)))
                self.debug_root_path = path.join(base_path, self.debug_root_path)
            except NameError:
                self.debug_root_path = path.abspath(self.debug_root_path)

        # Compiled kernels keyed by (args, sorted kwargs, sorted tune params).
        self._kernel_cache: Dict[tuple, Kernel] = {}

    # This tells the type checker what the *wrapper* function will return.
    # this is for linting, please do not remove it.
    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
        ...

    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
        ...

    # Actual implementation of __call__
    def __call__(
        self,
        func: Callable[_P, _RProg]  # func is Union[Callable[_P, _RProg], PrimFunc] in original
    ) -> Callable[_P, Any]:
        """
        Wrap ``func`` so that each distinct call compiles once and returns a kernel.

        The wrapper understands two private keyword arguments, removed before
        ``func`` is invoked:
          * ``__tune_params``: dict merged into the call of ``func`` and into the cache key.
          * ``__return_compile_arguments``: if truthy, return the stored compile
            options as a dict instead of compiling.
        """

        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
            # Separate out the tuning parameters from the user's kwargs
            tune_params = kwargs.pop('__tune_params', {})

            # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
            if kwargs.pop('__return_compile_arguments', False):
                return {
                    'out_idx': self.out_idx,
                    'execution_backend': self.execution_backend,
                    'target': self.target,
                    'target_host': self.target_host,
                    'verbose': self.verbose,
                    'pass_configs': self.pass_configs,
                    'compile_flags': self.compile_flags,
                }

            # NOTE: every positional argument and every kwarg/tune-param value must be
            # hashable, otherwise the dictionary lookup below raises TypeError.
            key = (
                args,
                tuple(sorted(kwargs.items())),
                tuple(sorted(tune_params.items())),
            )

            if key not in self._kernel_cache:
                if isinstance(func, PrimFunc):
                    program_result = func
                elif callable(func):
                    program_result = func(*args, **kwargs, **tune_params)
                else:
                    raise ValueError(f"Invalid function type: {type(func)}")

                kernel_result = compile(
                    program_result,
                    out_idx=self.out_idx,
                    execution_backend=self.execution_backend,
                    target=self.target,
                    target_host=self.target_host,
                    verbose=self.verbose,
                    pass_configs=self.pass_configs,
                    compile_flags=self.compile_flags,
                )

                if self.debug_root_path:
                    self._dump_debug_sources(func, kernel_result, program_result)

                self._kernel_cache[key] = kernel_result

            return self._kernel_cache[key]

        return wrapper

    def _dump_debug_sources(self, func: Any, kernel_result: Any, program_result: Any) -> None:
        """Write the kernel source and the program script into ``debug_root_path``.

        Files are named after ``func.__name__``, so a later compilation of the
        same function with different arguments overwrites the previous dump.
        """
        func_name = getattr(func, '__name__', 'jit_kernel')  # Use func for name
        kernel_file = f'tilelang_jit_kernel_{func_name}.c'
        program_file = f'tilelang_jit_program_{func_name}.py'

        makedirs(self.debug_root_path, exist_ok=True)
        with open(path.join(self.debug_root_path, kernel_file), 'w', encoding='utf-8') as f:
            print(kernel_result.get_kernel_source(), file=f)
        with open(path.join(self.debug_root_path, program_file), 'w', encoding='utf-8') as f:
            print(program_result.script(), file=f)


def jit(  # This is the new public interface
        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
        *,  # Indicates subsequent arguments are keyword-only
        out_idx: Any = None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        debug_root_path: Optional[str] = None,
        compile_flags: Optional[Union[List[str], str]] = None):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.

    It can be used bare, which applies JIT compilation with default settings::

        @tilelang.jit
        def program(...): ...

    or with keyword configuration, which returns a configured decorator::

        @tilelang.jit(out_idx=[-1], target="cuda")
        def program(...): ...

    Parameters
    ----------
    func : Callable, optional
        The function being decorated when used as bare ``@tilelang.jit``.
        Left as None when used as ``@tilelang.jit(...)``.
    out_idx : Any, optional
        Index(es) of the output tensors, forwarded to ``compile``. Defaults to None.
    target : Union[str, Target], optional
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Backend for kernel execution and argument passing. Defaults to "cython".
    verbose : bool, optional
        Enables verbose output during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory in which to save the compiled kernel source and program
        script for debugging. Defaults to None.
    compile_flags : Optional[Union[List[str], str]], optional
        Additional compiler flags; a single string is wrapped into a list.
        Defaults to None.

    Returns
    -------
    Callable
        If ``func`` is callable, the JIT-compiled wrapper around it. Otherwise,
        a configured decorator instance that can then be applied to a function.

    Raises
    ------
    ValueError
        If ``func`` is a ``PrimFunc`` that is not callable. Decorating a
        PrimFunc directly is not supported.
    """
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]

    # The callable check takes precedence (as in the original branch order), so
    # this only rejects PrimFunc objects that are not themselves callable.
    if not callable(func) and isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")

    decorator = _JitImplementation(
        out_idx=out_idx,
        target=target,
        target_host=target_host,
        execution_backend=execution_backend,
        verbose=verbose,
        pass_configs=pass_configs,
        debug_root_path=debug_root_path,
        compile_flags=compile_flags)

    # Bare ``@jit``: apply the decorator immediately.
    # ``@jit(...)``: return it to be applied to the function later.
    # NOTE(review): any other non-callable ``func`` (e.g. an out_idx list passed
    # positionally) is ignored here; out_idx must be given as a keyword.
    return decorator(func) if callable(func) else decorator