"docs/en_US/vscode:/vscode.git/clone" did not exist on "8fb8f8b3cef0757401c1a7ed0fb1c8e3f659c5c4"
__init__.py 22.7 KB
Newer Older
1
"""
2
3
This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
4
5
kernel adapter using TVM.
"""
6

7
from __future__ import annotations
8

9
10
from dataclasses import dataclass
import inspect
11
12
13
from typing import (
    Any,
    Callable,
14
15
    Generic,
    TypeVar,
16
17
18
    overload,
    Literal,
)
19
from collections.abc import Iterable
20

21
22
23
24
25
# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec
26
from tilelang import tvm as tvm
27
28
from tilelang.language.v2 import PrimFunc, PrimFuncCreater, prim_func
from tilelang.language.v2.annot import Annot
29
30
31
from tvm.target import Target

from tilelang.jit.kernel import JITKernel
32
from tilelang.utils.target import determine_target
33
from tilelang.cache import cached
34
from os import path, makedirs
35
from logging import getLogger
36
37
38
39
from tilelang.jit.param import Kernel
import concurrent.futures

from tqdm.auto import tqdm
40
41
42

# Module-level logger, named after this module; used for backend-resolution
# info messages and for compilation warnings below.
logger = getLogger(__name__)

43
44
45
46
# Generic parameters shared by the signatures in this module:
#   _P   - parameters of the decorated Python callable (JITImpl call-site arguments)
#   _KP  - parameters of the PrimFunc / JITKernel
#   _T   - second type parameter of PrimFunc / JITKernel
#   _Ret - what JITImpl.compile / JITImpl.__call__ return
_P = ParamSpec("_P")
_KP = ParamSpec("_KP")
_T = TypeVar("_T")
_Ret = TypeVar("_Ret")
47

48

49
def compile(
    func: PrimFunc[_KP, _T] | None = None,
    out_idx: list[int] | int | None = None,
    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    compile_flags: list[str] | str | None = None,
) -> JITKernel[_KP, _T]:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.

    Parameters
    ----------
    func : tvm.tir.PrimFunc
        The TileLang TIR function to compile and wrap. Although the parameter
        has a default of None, a PrimFunc instance is required.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
        Execution backend to use for kernel execution. Use "auto" to pick a sensible
        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional
        Target host for cross-compilation (default: None).
    verbose : bool, optional
        Whether to enable verbose output (default: False). When enabled, the
        resolved execution backend is also logged.
    pass_configs : dict, optional
        Additional keyword arguments to pass to the Compiler PassContext.
        Refer to `tilelang.transform.PassConfigKey` for supported options.
    compile_flags : Union[List[str], str], optional
        Additional compiler flags, forwarded unchanged to `tilelang.cache.cached`
        (default: None).

    Returns
    -------
    JITKernel
        The result of `tilelang.cache.cached(...)` for the resolved arguments.

    Raises
    ------
    AssertionError
        If `func` is not a PrimFunc (this includes the default None).
    ValueError
        If `out_idx` is given while `func.out_idx_override` is also set.
    """
    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"

    # A PrimFunc may carry its own output indices; an explicit `out_idx` on top
    # of that is ambiguous and rejected.
    out_idx_override = getattr(func, "out_idx_override", None)
    if out_idx_override is not None and out_idx is not None:
        raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors")
    out_idx = out_idx_override or out_idx

    # This path is not a performance critical path, so we can afford to convert the target.
    target = Target(determine_target(target))

    # Resolve execution backend (handles aliases, auto, validation per target).
    # Imported at call time, as in the original code.
    from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target

    requested_backend = execution_backend
    execution_backend = resolve_execution_backend(requested_backend, target)
    if verbose:
        allowed_now = allowed_backends_for_target(target, include_unavailable=False)
        logger.info(
            "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)",
            execution_backend,
            requested_backend,
            target.kind.name,
            ", ".join(sorted(allowed_now)),
        )

    return cached(
        func=func,
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
        compile_flags=compile_flags,
    )


118
119
120
def par_compile(
    funcs: Iterable[PrimFunc[_KP, _T]],
    out_idx: list[int] | int | None = None,
    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    compile_flags: list[str] | str | None = None,
    num_workers: int | None = None,
    ignore_error: bool = False,
) -> list[JITKernel[_KP, _T]]:
    """
    Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.

    Every function is submitted to a thread pool that runs the module-level
    `compile` with the same options; a progress bar tracks completion.

    Parameters
    ----------
    funcs : Iterable[tvm.tir.PrimFunc]
        The TileLang TIR functions to compile and wrap.
    out_idx : Union[List[int], int], optional
        Index(es) of the output tensors to return (default: None).
    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
        Execution backend to use for kernel execution. Use "auto" to pick a sensible
        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    target : Union[str, Target], optional
        Compilation target, either as a string or a TVM Target object (default: "auto").
    target_host : Union[str, Target], optional
        Target host for cross-compilation (default: None).
    verbose : bool, optional
        Whether to enable verbose output (default: False).
    pass_configs : dict, optional
        Additional keyword arguments to pass to the Compiler PassContext.
        Refer to `tilelang.transform.PassConfigKey` for supported options.
    compile_flags : Union[List[str], str], optional
        Additional compiler flags, forwarded unchanged to `compile` (default: None).
    num_workers : int, optional
        Maximum number of worker threads. None lets `ThreadPoolExecutor`
        choose its default (default: None).
    ignore_error : bool, optional
        If True, a failing compilation is logged as a warning and its slot in
        the result is None. If False, the first failure encountered is
        re-raised (default: False).

    Returns
    -------
    List[JITKernel]
        Compiled kernels in the same order as `funcs`; entries are None for
        failed compilations when `ignore_error` is True.
    """
    with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor:
        # Remember each future's input position so results can be returned in
        # submission order even though they complete in arbitrary order.
        future_to_index = {
            executor.submit(
                compile,
                func=func,
                out_idx=out_idx,
                execution_backend=execution_backend,
                target=target,
                target_host=target_host,
                verbose=verbose,
                pass_configs=pass_configs,
                compile_flags=compile_flags,
            ): index
            for index, func in enumerate(funcs)
        }
        results = [None] * len(future_to_index)
        for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(future_to_index),
            desc="Parallel Compiling",
        ):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as e:
                if not ignore_error:
                    raise
                logger.warning("Error compiling function at index %s: %s", index, e)
                results[index] = None
    return results


@dataclass
class JITImpl(Generic[_P, _KP, _T, _Ret]):
    """
    Configured Just-In-Time wrapper around a TileLang program.

    Instances are created by the `jit` and `lazy_jit` decorators of this
    module. An instance stores the compilation options, turns call-site
    arguments into a TIR PrimFunc (`get_tir`), compiles it through the
    module-level `compile` / `par_compile` functions, and memoizes compiled
    kernels per call-site key (`__call__`).

    Attributes
    ----------
    out_idx : list[int] | int | None
        Output tensor index(es), forwarded to the module-level `compile`.
    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
        Backend used for executing the generated kernel.
    target : str | tvm.target.Target
        TVM compilation target (e.g. "cuda", "llvm", or "auto").
    target_host : str | tvm.target.Target | None
        Host target used for cross-compilation, or None.
    verbose : bool
        Enable verbose messages during compilation/build.
    pass_configs : dict[str, Any] | None
        Extra TVM pass configuration options forwarded to the compiler's
        PassContext.
    debug_root_path : str | None
        If set, `compile` writes the kernel source and the elaborated program
        script into this directory. A relative path is made absolute in
        `__post_init__`.
    compile_flags : list[str] | str | None
        Additional compiler flags, forwarded to the module-level `compile`.
    func_source : str
        Source text of the original Python function.
    signature : inspect.Signature
        Signature of the original Python function.
    lazy_jit : bool
        True when the instance was created by `lazy_jit`. In that mode
        `__call__` converts the call arguments to kernel arguments, runs the
        kernel and returns its result; otherwise `__call__` returns the
        compiled kernel itself.
    func : Callable | PrimFunc | PrimFuncCreater
        The wrapped object: a callable returning a PrimFunc, a
        PrimFuncCreater, or an already-constructed PrimFunc. Stored last so it
        appears at the end of the dataclass `__repr__`.

    Behavioral summary
    ------------------
    - get_tir(*args, **kwargs)
        Produces the PrimFunc for the given call-site arguments (or returns
        the wrapped PrimFunc as-is).
    - compile(*args, **kwargs)
        Elaborates and compiles a single PrimFunc; optionally dumps debug
        files under `debug_root_path`.
    - par_compile(configs, ...)
        Elaborates one PrimFunc per config and compiles them in parallel,
        returning kernels in config order.
    """

    out_idx: list[int] | int | None
    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
    target: str | Target
    target_host: str | Target | None
    verbose: bool
    pass_configs: dict[str, Any] | None
    debug_root_path: str | None
    compile_flags: list[str] | str | None
    func_source: str
    signature: inspect.Signature
    lazy_jit: bool
    # place func at the last element for better __repr__
    func: Callable[_P, _T] | PrimFunc[_KP, _T]

    @property
    def annot(self) -> dict[str, Annot]:
        """Annotation mapping of the wrapped function; only valid in lazy-JIT mode."""
        assert self.lazy_jit, "annot is only support in @tilelang.jit2"
        return self.func.func_annot.annots

    def __post_init__(self):
        """Normalize `debug_root_path` and create the per-instance caches."""
        if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
            try:
                # Relative debug paths are anchored three directory levels
                # above this file.
                base_path = path.dirname(path.dirname(path.dirname(__file__)))
                self.debug_root_path = path.join(base_path, self.debug_root_path)
            except NameError:
                # `__file__` is not defined: fall back to the current working directory.
                self.debug_root_path = path.abspath(self.debug_root_path)
        # Compiled kernels keyed by `parse_cache_key(...)`.
        self._kernel_cache: dict[tuple, Kernel] = {}
        # Not used inside this class; presumably consumed by the autotuner - TODO confirm.
        self._tuner_cache: dict[tuple, Kernel] = {}

    def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
        """
        Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.

        A PrimFuncCreater or plain callable is invoked with the given
        arguments; a PrimFunc is returned unchanged (arguments are ignored).

        Raises
        ------
        ValueError
            If the wrapped object is neither a PrimFuncCreater, a PrimFunc, nor callable.
        AssertionError
            If the elaborated result is not a PrimFunc.
        """
        if isinstance(self.func, PrimFuncCreater):
            tir = self.func(*args, **kwargs)
        elif isinstance(self.func, PrimFunc):
            tir = self.func
        elif callable(self.func):
            tir = self.func(*args, **kwargs)
        else:
            raise ValueError(f"Invalid function type: {type(self.func)}")
        assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
        return tir

    def par_compile(
        self, configs: Iterable[dict[str, Any] | tuple[Any, ...]], num_workers: int | None = None, ignore_error: bool = False
    ) -> list[JITKernel[_KP, _T]]:
        """
        Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.

        Parameters
        ----------
        configs : Iterable[Union[dict[str, Any], tuple[Any, ...]]]
            The configurations to elaborate and compile. Each config can be either
            a dictionary mapping keyword arguments to values, or a tuple of positional
            arguments.
        num_workers : int, optional
            Number of parallel workers to use for compilation. Defaults to None,
            which lets the system decide.
        ignore_error : bool, optional
            If True, compilation errors for individual configs will be logged
            as warnings and the corresponding result will be None. If False,
            any compilation error will raise an exception. Defaults to False.

        Returns
        -------
        List[JITKernel]
            A list of compiled JITKernel objects corresponding to the provided configs.

        Raises
        ------
        ValueError
            If a config is neither a tuple nor a dict.
        """
        # Materialize first so the progress bar knows the total.
        configs = list(configs)
        funcs = []
        for cfg in tqdm(configs, desc="Elaborating"):
            if isinstance(cfg, tuple):
                funcs.append(self.get_tir(*cfg))
            elif isinstance(cfg, dict):
                funcs.append(self.get_tir(**cfg))
            else:
                raise ValueError(f"Invalid config type: {type(cfg)}, expected tuple or dict.")
        return par_compile(
            funcs,
            out_idx=self.out_idx,
            execution_backend=self.execution_backend,
            target=self.target,
            target_host=self.target_host,
            verbose=self.verbose,
            pass_configs=self.pass_configs,
            compile_flags=self.compile_flags,
            num_workers=num_workers,
            ignore_error=ignore_error,
        )

    def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
        """
        Elaborate the PrimFunc for the given arguments and compile it.

        When `debug_root_path` is set, the kernel source is written to
        `tilelang_jit_kernel_<name>.c` and the PrimFunc script to
        `tilelang_jit_program_<name>.py` inside that directory.

        Returns
        -------
        The kernel produced by the module-level `compile`.
        """
        func = self.get_tir(*args, **kwargs)
        kernel_result = compile(
            func,
            out_idx=self.out_idx,
            execution_backend=self.execution_backend,
            target=self.target,
            target_host=self.target_host,
            verbose=self.verbose,
            pass_configs=self.pass_configs,
            compile_flags=self.compile_flags,
        )

        if self.debug_root_path:
            if isinstance(self.func, PrimFunc):
                func_name = self.func.attrs["global_symbol"]
            else:
                func_name = getattr(self.func, "__name__", "jit_kernel")
            kernel_file = f"tilelang_jit_kernel_{func_name}.c"
            program_file = f"tilelang_jit_program_{func_name}.py"
            makedirs(self.debug_root_path, exist_ok=True)
            with open(path.join(self.debug_root_path, kernel_file), "w") as f:
                print(kernel_result.get_kernel_source(), file=f)
            with open(path.join(self.debug_root_path, program_file), "w") as f:
                print(func.script(), file=f)

        return kernel_result

    def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
        """
        Build the kernel-cache key for a call.

        The optional `__tune_params` keyword (a dict) is separated from the
        user's kwargs. For a PrimFuncCreater the key is delegated to its
        annotation's `parse_key`; otherwise it is the tuple
        `(args, sorted kwargs items, sorted tune-param items)`.
        """
        tune_params = kwargs.pop("__tune_params", {})
        if isinstance(self.func, PrimFuncCreater):
            return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
        key_kwargs_tuple = tuple(sorted(kwargs.items()))
        tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
        return (args, key_kwargs_tuple, tuned_key_kwargs_tuple)

    def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
        """
        Convert call-site arguments into kernel arguments.

        Raises
        ------
        NotImplementedError
            If the wrapped object is not a PrimFuncCreater.
        """
        if isinstance(self.func, PrimFuncCreater):
            tune_params = kwargs.pop("__tune_params", {})
            return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
        else:
            raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
        """
        Compile (or fetch from cache) the kernel for the given arguments.

        Special keyword arguments, removed before elaboration:
        - `__return_compile_arguments` (deprecated): when truthy, return a dict
          of this instance's compile options instead of compiling.
        - `__tune_params`: dict of extra keyword arguments merged into the
          elaboration call and included in the cache key.

        Returns
        -------
        In lazy-JIT mode, the result of running the kernel on the converted
        arguments; otherwise the compiled kernel.
        """
        # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
        return_compile_arguments = kwargs.pop("__return_compile_arguments", False)
        if return_compile_arguments:
            logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.")
            compile_args = {
                "out_idx": self.out_idx,
                "execution_backend": self.execution_backend,
                "target": self.target,
                "target_host": self.target_host,
                "verbose": self.verbose,
                "pass_configs": self.pass_configs,
                "compile_flags": self.compile_flags,
            }
            return compile_args

        # `parse_cache_key` receives its own copy of kwargs, so `__tune_params`
        # is still present here and must be separated out afterwards.
        key = self.parse_cache_key(*args, **kwargs)
        tune_params = kwargs.pop("__tune_params", {})

        kernel = self._kernel_cache.get(key, None)
        if kernel is None:
            kernel = self.compile(*args, **kwargs, **tune_params)
            self._kernel_cache[key] = kernel

        if self.lazy_jit:
            args = self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
            return kernel(*args)
        else:
            return kernel
425

426

427
# Backend names accepted by the `execution_backend` argument of the `jit` and
# `lazy_jit` decorators below.
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
428
429
430


@overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ...


@overload
def jit(
    *,  # Indicates subsequent arguments are keyword-only
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
    compile_flags: list[str] | str | None = None,
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ...


def jit(  # This is the new public interface
    func: Callable[_P, _T] | PrimFunc | None = None,
    *,  # Indicates subsequent arguments are keyword-only
    out_idx: Any = None,
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
    compile_flags: list[str] | str | None = None,
):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.

    This decorator can be used without arguments (e.g., `@tilelang.jit`), which
    applies JIT compilation with default settings, or with keyword arguments
    (e.g., `@tilelang.jit(target="cuda")`), which returns a configured
    decorator.

    Parameters
    ----------
    func : Callable or PrimFunc, optional
        The function to decorate when used as a bare `@tilelang.jit`. Leave as
        None (the default) when configuring the decorator with keyword arguments.
    out_idx : Any, optional
        Index(es) of the output tensors to return. Defaults to None.
    target : Union[str, Target], optional
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
        Backend for kernel execution and argument passing. Use "auto" to pick a sensible
        default per target (cuda->tvm_ffi, metal->torch, others->cython).
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory to save compiled kernel source for debugging. Defaults to None.
    compile_flags : Optional[Union[List[str], str]], optional
        Additional compiler flags stored on the resulting JITImpl. Defaults to None.

    Returns
    -------
    Callable
        Either a JIT-compiled wrapper around the input function, or a configured decorator
        instance that can then be applied to a function.
    """

    def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
        # Source text and signature are taken from the underlying Python
        # function, which PrimFunc / PrimFuncCreater expose as `orig_func`.
        if isinstance(func, (PrimFunc, PrimFuncCreater)):
            orig_func = func.orig_func
        else:
            orig_func = func
        return JITImpl(
            func=func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
            debug_root_path=debug_root_path,
            compile_flags=compile_flags,
            func_source=inspect.getsource(orig_func),
            signature=inspect.signature(orig_func),
            lazy_jit=False,
        )

    if func is not None:
        return decorator(func)
    else:
        return decorator
517
518
519


@overload
def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ...


# NOTE: unlike `jit`, there is no `out_idx` keyword here; the implementation
# below does not accept one and always passes `out_idx=None` to JITImpl.
@overload
def lazy_jit(
    *,
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
    compile_flags: list[str] | str | None = None,
) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ...


def lazy_jit(
    func: Callable[_P, _T] | PrimFunc | None = None,
    *,  # Indicates subsequent arguments are keyword-only
    target: str | Target = "auto",
    target_host: str | Target | None = None,
    execution_backend: ExecutionBackend = "auto",
    verbose: bool = False,
    pass_configs: dict[str, Any] | None = None,
    debug_root_path: str | None = None,
    compile_flags: list[str] | str | None = None,
):
    """
    Lazy JIT decorator: wrap a function so that calling it compiles and runs the kernel.

    The decorated function is passed through `prim_func(func, generator=True)`
    and wrapped in a `JITImpl` with `lazy_jit=True` and `out_idx=None`. Can be
    used bare (`@lazy_jit`) or configured with keyword arguments
    (`@lazy_jit(target="cuda")`).

    Parameters
    ----------
    func : Callable, optional
        The function to decorate when used bare. Leave as None (the default)
        when configuring the decorator with keyword arguments.
    target : Union[str, Target], optional
        Compilation target for TVM. Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : ExecutionBackend, optional
        Backend for kernel execution. Defaults to "auto".
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory to save compiled kernel source for debugging. Defaults to None.
    compile_flags : Optional[Union[List[str], str]], optional
        Additional compiler flags stored on the resulting JITImpl. Defaults to None.

    Returns
    -------
    JITImpl or Callable
        The wrapped function when `func` is given, otherwise a decorator that
        produces one.
    """
    compile_args = {
        "out_idx": None,
        "execution_backend": execution_backend,
        "target": target,
        "target_host": target_host,
        "verbose": verbose,
        "pass_configs": pass_configs,
        "debug_root_path": debug_root_path,
        "compile_flags": compile_flags,
    }

    def decorator(func: Callable[_P, _T]):
        pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
        return JITImpl(
            func=pf,
            **compile_args,
            func_source=inspect.getsource(pf.orig_func),
            signature=inspect.signature(pf.orig_func),
            lazy_jit=True,
        )

    return decorator(func) if func is not None else decorator