"""
This module provides the just-in-time (JIT) compilation infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""

from typing import (
    Any,
    List,
    Union,
    Callable,
    Tuple,
    overload,
    Literal,
    Dict,  # For type hinting dicts
    Optional,
)
18
19
20
21
22
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target

from tilelang.jit.kernel import JITKernel
23
from tilelang.utils.target import determine_target
24
from tilelang.cache import cached
25
from os import path, makedirs
26
from logging import getLogger
27
import functools
28
from tilelang.jit.param import Kernel, _P, _RProg
29
30
31
32

logger = getLogger(__name__)


33
34
def compile(
    func: PrimFunc = None,
    out_idx: Union[List[int], int, None] = None,
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
    target: Union[str, Target] = "auto",
    target_host: Union[str, Target, None] = None,
    verbose: bool = False,
    pass_configs: Optional[Dict[str, Any]] = None,
    compile_flags: Optional[Union[List[str], str]] = None,
) -> JITKernel:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.

    Parameters
    ----------
    func : tvm.tir.PrimFunc, optional
        The TileLang TIR function to compile and wrap. Although the parameter
        has a default of None, a PrimFunc instance is required.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Execution backend to use for kernel execution (default: "cython").
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional
        Target host for cross-compilation (default: None).
    verbose : bool, optional
        Whether to enable verbose output (default: False).
    pass_configs : dict, optional
        Additional keyword arguments to pass to the Compiler PassContext.
        Available options:
            "tir.disable_vectorize": bool, default: False
            "tl.disable_tma_lower": bool, default: False
            "tl.disable_warp_specialized": bool, default: False
            "tl.config_index_bitwidth": int, default: None
            "tl.disable_dynamic_tail_split": bool, default: False
            "tl.dynamic_vectorize_size_bits": int, default: 128
            "tl.disable_safe_memory_legalize": bool, default: False
    compile_flags : Union[List[str], str], optional
        Extra compile flags forwarded to `cached`. A single string is wrapped
        into a one-element list before forwarding (default: None).

    Returns
    -------
    JITKernel
        The value returned by `tilelang.cache.cached` for these arguments.

    Raises
    ------
    AssertionError
        If `func` is not a `PrimFunc`.
    """
    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"

    # Accept a lone flag string as shorthand for a one-element list.
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]

    # This path is not a performance critical path, so we can afford to convert the target.
    target = Target(determine_target(target))

    return cached(
        func=func,
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
        compile_flags=compile_flags,
    )


89
class _JitImplementation:
    """
    Configurable decorator object behind `jit`.

    Calling an instance on a function returns a wrapper. The wrapper evaluates
    the function to obtain a program, compiles it with `compile`, and memoizes
    the resulting kernel per distinct set of call arguments.
    """

    out_idx: Optional[Union[List[int], int]]
    target: Union[str, Target]
    target_host: Optional[Union[str, Target]]
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"]
    verbose: bool
    pass_configs: Optional[Dict[str, Any]]
    debug_root_path: Optional[str]
    compile_flags: Optional[List[str]]

    def __init__(self,
                 out_idx: Any = None,
                 target: Union[str, Target] = "auto",
                 target_host: Optional[Union[str, Target]] = None,
                 execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
                 verbose: bool = False,
                 pass_configs: Optional[Dict[str, Any]] = None,
                 debug_root_path: Optional[str] = None,
                 compile_flags: Optional[List[str]] = None):
        """
        Initializes the JIT compiler decorator.

        Parameters
        ----------
        out_idx : Any, optional
            Index(es) of the output tensors to return from the compiled kernel
            (default: None).
        target : Union[str, Target], optional
            Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
            or a TVM Target object (default: "auto").
        target_host : Union[str, Target], optional
            Target host for cross-compilation, similar to `target` (default: None).
        execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
            The backend used for kernel execution and argument passing
            (default: "cython").
        verbose : bool, optional
            If True, enables verbose logging during compilation (default: False).
        pass_configs : Optional[Dict[str, Any]], optional
            A dictionary of configurations for TVM's pass context, e.g.
            "tir.disable_vectorize" (default: None).
        debug_root_path : Optional[str], optional
            If provided, the compiled kernel's source code and the program
            script are written to files in this directory on each fresh
            compilation. A relative path is joined onto the directory three
            levels above this file; if `__file__` is unavailable it is
            resolved against the current working directory instead.
            If None, nothing is written (default: None).
        compile_flags : Optional[List[str]], optional
            Extra compile flags forwarded to `compile` (default: None).
        """
        self.out_idx = out_idx
        self.execution_backend = execution_backend
        self.target = target
        self.target_host = target_host
        self.verbose = verbose
        self.pass_configs = pass_configs
        self.compile_flags = compile_flags

        # Relative debug paths are anchored three directories above this file;
        # the NameError fallback covers contexts where `__file__` is undefined.
        self.debug_root_path = debug_root_path
        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
            try:
                base_path = path.dirname(path.dirname(path.dirname(__file__)))
                self.debug_root_path = path.join(base_path, self.debug_root_path)
            except NameError:
                self.debug_root_path = path.abspath(self.debug_root_path)

        # Maps (args, sorted kwargs, sorted tune params) -> compiled kernel.
        self._kernel_cache: Dict[tuple, Kernel] = {}

    # This tells the type checker what the *wrapper* function will return.
    # this is for linting, please do not remove it.
    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
        ...

    @overload
    def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
        ...

    # Actual implementation of __call__
    def __call__(
        self,
        func: Callable[_P, _RProg]  # func is Union[Callable[_P, _RProg], PrimFunc] in original
    ) -> Callable[_P, Any]:
        """
        Wrap `func` so that calling it compiles (once per distinct argument
        set) and returns the cached kernel.

        The wrapper recognises two reserved keyword arguments, which are
        removed before `func` is invoked:

        - ``__tune_params``: dict of extra keyword arguments passed to `func`
          and included in the cache key.
        - ``__return_compile_arguments``: when truthy, the wrapper returns a
          dict of this decorator's compile settings instead of compiling.
        """

        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
            # Separate out the tuning parameters from the user's kwargs
            tune_params = kwargs.pop('__tune_params', {})

            # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
            return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
            if return_compile_arguments:
                compile_args = {
                    'out_idx': self.out_idx,
                    'execution_backend': self.execution_backend,
                    'target': self.target,
                    'target_host': self.target_host,
                    'verbose': self.verbose,
                    'pass_configs': self.pass_configs,
                    'compile_flags': self.compile_flags,
                }
                return compile_args

            # The cache key requires every argument value to be hashable.
            key_args_tuple = args
            key_kwargs_tuple = tuple(sorted(kwargs.items()))
            tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
            key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)

            if key not in self._kernel_cache:
                # Ensure 'func' (the original user function) is used correctly
                program_result_source = func
                if isinstance(program_result_source, PrimFunc):
                    program_result = program_result_source
                elif callable(program_result_source):
                    program_result = program_result_source(*args, **kwargs, **tune_params)
                else:
                    raise ValueError(f"Invalid function type: {type(program_result_source)}")

                kernel_result = compile(
                    program_result,
                    out_idx=self.out_idx,
                    execution_backend=self.execution_backend,
                    target=self.target,
                    target_host=self.target_host,
                    verbose=self.verbose,
                    pass_configs=self.pass_configs,
                    compile_flags=self.compile_flags,
                )

                # Debug dump happens only on a cache miss; file names depend
                # solely on the function name, so later compilations of the
                # same function overwrite earlier dumps.
                if self.debug_root_path:
                    func_name = getattr(func, '__name__', 'jit_kernel')  # Use func for name
                    kernel_file = f'tilelang_jit_kernel_{func_name}.c'
                    program_file = f'tilelang_jit_program_{func_name}.py'
                    makedirs(self.debug_root_path, exist_ok=True)
                    with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
                        print(kernel_result.get_kernel_source(), file=f)
                    with open(path.join(self.debug_root_path, program_file), 'w') as f:
                        print(program_result.script(), file=f)

                self._kernel_cache[key] = kernel_result

            return self._kernel_cache[key]

        return wrapper
235
236
237
238
239
240
241
242


def jit(  # This is the new public interface
        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
        *,  # Indicates subsequent arguments are keyword-only
        out_idx: Any = None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target, None] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        debug_root_path: Optional[str] = None,
        compile_flags: Optional[Union[List[str], str]] = None):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.

    This decorator can be used without arguments (e.g., `@tilelang.jit`):
       Applies JIT compilation with default settings.

    It can also be called with keyword arguments (e.g.,
    `@tilelang.jit(out_idx=..., target=...)`):
       Returns a configured decorator that is then applied to the function.

    Parameters
    ----------
    func : Union[Callable, PrimFunc, None], optional
        The function to decorate when used as a bare `@tilelang.jit`.
        Left as None when the decorator is called with keyword arguments.
        A non-callable value that is not a PrimFunc is ignored and a
        configured decorator is returned.
    out_idx : Any, optional
        Index(es) of the output tensors to return. Defaults to None.
    target : Union[str, Target], optional
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
        Backend for kernel execution and argument passing. Defaults to "cython".
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory to save compiled kernel source for debugging. Defaults to None.
    compile_flags : Optional[Union[List[str], str]], optional
        Extra compile flags. A single string is wrapped into a one-element
        list. Defaults to None.

    Returns
    -------
    Callable
        Either a JIT-compiled wrapper around the input function, or a configured decorator
        instance that can then be applied to a function.

    Raises
    ------
    ValueError
        If `func` is a non-callable `PrimFunc`.
    """
    # Accept a lone flag string as shorthand for a one-element list.
    if isinstance(compile_flags, str):
        compile_flags = [compile_flags]

    # Callability is tested first; only a non-callable PrimFunc is rejected.
    decorates_directly = callable(func)
    if not decorates_directly and isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")

    # Both usage forms share the same configuration; they differ only in
    # whether the decorator is applied to `func` immediately.
    decorator = _JitImplementation(
        out_idx=out_idx,
        target=target,
        target_host=target_host,
        execution_backend=execution_backend,
        verbose=verbose,
        pass_configs=pass_configs,
        debug_root_path=debug_root_path,
        compile_flags=compile_flags)

    if decorates_directly:
        # Case 1: used as a bare @jit on a function.
        return decorator(func)
    # Case 2: used as @jit(...); the instance is applied to the function later.
    return decorator