# kernel.py
1
2
from typing import Any, Callable, Dict, List, Literal, Optional, Union

3
4
5
from tvm.target import Target
from tvm.tir import PrimFunc

6
7
8
9
10
import tilelang
from tilelang import tvm as tvm
from tilelang.engine.param import CompiledArtifact, KernelParam
from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
                                  NVRTCKernelAdapter, TorchDLPackKernelAdapter)
11
from tilelang.profiler import Profiler, TensorSupplyType
12
from tilelang.utils.target import AVALIABLE_TARGETS, determine_target
13
14
15
16
17
18
19
20


class JITKernel(object):
    """
    A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

    Attributes
    ----------
    prim_func : PrimFunc
        The original TileLang TIR function wrapped by this kernel.
    artifact : CompiledArtifact
        The compiled artifact containing the runtime module and parameters.
    adapter : BaseKernelAdapter
        The adapter for the compiled function.
    torch_function : Callable
        The compiled function that can be invoked as a PyTorch-compatible function.
    """

    prim_func: PrimFunc = None
    artifact: CompiledArtifact = None
    adapter: BaseKernelAdapter = None
    torch_function: Callable = None

    # Tuner results, populated by update_tuner_result() after auto-tuning.
    latency: float = None
    config: Dict[str, Any] = None
    ref_latency: float = None

    def __init__(
        self,
        func: PrimFunc = None,
        out_idx: Union[List[int], int] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        from_database: bool = False,
    ):
        """
        Initializes a TorchFunction instance.

        Parameters
        ----------
        func : tvm.tir.PrimFunc, optional
            The TileLang TIR function to compile and wrap.
        out_idx : Union[List[int], int], optional
            Index(es) of the output tensors to return (default: None).
        execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
            Execution backend to use for kernel execution (default: "cython").
        target : Union[str, Target], optional
            Compilation target, either as a string or a TVM Target object (default: "auto").
        target_host : Union[str, Target], optional
            Target host for cross-compilation (default: None).
        verbose : bool, optional
            Whether to enable verbose output (default: False).
        pass_configs : dict, optional
            Additional keyword arguments to pass to the Compiler PassContext.
            Available options:
                "tir.disable_vectorize": bool, default: False
                "tl.disable_tma_lower": bool, default: False
                "tl.disable_dynamic_tail_split": bool, default: False
                "tl.dynamic_vectorize_size_bits": int, default: 128
        from_database : bool, optional
            Whether to create a TorchFunction from a database.
        """
        self.prim_func = func
        self.execution_backend = execution_backend
        # NOTE(review): the raw (possibly "auto") target is stored on self;
        # the normalized TVM Target computed below is used only for validation
        # in this constructor, and compilation later re-reads self.target —
        # confirm this is intentional.
        self.target = target
        self.target_host = target_host
        self.verbose = verbose

        # Normalize None to a fresh dict (avoids a shared mutable default).
        if pass_configs is None:
            pass_configs = {}
        self.pass_configs = pass_configs

        # If the target is specified as a string, validate it and convert it to a TVM Target.
        if isinstance(target, str):
            assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
            target = determine_target(target)

        # Ensure the target is always a TVM Target object.
        target = Target(target)

        # Validate the execution backend.
        assert execution_backend in [
            "dlpack",
            "ctypes",
            "cython",
            "nvrtc",
        ], f"Invalid execution backend. {execution_backend}"

        if execution_backend == "cython":
            # The Cython adapter builds a native shim at runtime, which
            # requires a host C++ compiler to be installed.
            from tilelang.contrib.cc import get_cplus_compiler

            assert (
                get_cplus_compiler() is not None
            ), "Cython backend requires a C++ compiler, please install or use other backends."

        if from_database:
            # from_database() populates the adapter itself; skip compilation.
            return

        # Compile the TileLang function and create a kernel adapter for execution.
        adapter = self._compile_and_create_adapter(func, out_idx)

        # The adapter's function is assigned as the callable function for this instance.
        self.adapter = adapter
        self.torch_function = adapter.func

    @classmethod
    def from_database(
        cls,
        func: PrimFunc,
        kernel_global_source: str,
        kernel_lib_path: str,
        params: List[KernelParam],
        target: Union[str, Target],
        target_host: Union[str, Target],
        out_idx: Union[List[int], int],
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
        pass_configs: Optional[Dict[str, Any]] = None,
    ):
        """
        Alternative constructor to create a TorchFunction directly from a database.

        The kernel is not recompiled; the adapter is reconstructed from the
        cached kernel source and shared library.
        """
        instance = cls(
            func=func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
            target_host=target_host,
            pass_configs=pass_configs,
            from_database=True,
        )

        instance.adapter = instance._create_adapter_from_database(
            func_or_mod=func,
            params=params,
            result_idx=out_idx,
            target=target,
            kernel_global_source=kernel_global_source,
            kernel_lib_path=kernel_lib_path,
            pass_configs=pass_configs,
        )
        instance.torch_function = instance.adapter.func
        return instance

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """
        Invokes the compiled function with the given arguments.

        Parameters
        ----------
        *args : Any
            Positional arguments for the function.
        **kwds : Any
            Keyword arguments for the function.

        Returns
        -------
        Any
            The result of the function execution.
        """
        return self.torch_function(*args, **kwds)

    def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
                                    out_idx: List[int]) -> BaseKernelAdapter:
        """
        Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.

        Parameters
        ----------
        tilelang_func : tvm.tir.PrimFunc
            The TileLang (TVM TIR) function to compile.
        out_idx : List[int]
            Index(es) of the output tensors the adapter should return.

        Returns
        -------
        BaseKernelAdapter
            The compiled and ready-to-run kernel adapter.

        Raises
        ------
        ValueError
            If the execution backend is not one of the supported backends.
        """
        verbose = self.verbose
        target = self.target
        target_host = self.target_host
        execution_backend = self.execution_backend
        pass_configs = self.pass_configs

        # Only the DLPack backend goes through TVM host codegen / device
        # compilation; the source-level backends build the kernel themselves.
        enable_host_codegen = execution_backend == "dlpack"
        enable_device_compile = execution_backend == "dlpack"

        # Compile the function with TVM, optimizing with shared memory lowering.
        with tvm.transform.PassContext(opt_level=3, config=pass_configs):
            artifact = tilelang.lower(
                tilelang_func,
                target=target,
                target_host=target_host,
                enable_host_codegen=enable_host_codegen,
                enable_device_compile=enable_device_compile)

        self.artifact = artifact

        # Create an adapter based on the specified execution backend.
        if execution_backend == "dlpack":
            # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
            # But we need to ensure that the runtime is enabled and the runtime module is not None.
            assert tvm.runtime.enabled("llvm"), "DLPack backend requires LLVM runtime."
            assert (artifact.rt_mod is not None), "DLPack backend requires a runtime module."
            return TorchDLPackKernelAdapter(
                artifact.rt_mod, params=artifact.params, result_idx=out_idx)

        # The remaining source-level backends share an identical constructor
        # signature, so dispatch through a table instead of duplicating the
        # keyword-argument list three times.
        source_backend_adapters = {
            "ctypes": CtypesKernelAdapter,
            "cython": CythonKernelAdapter,
            "nvrtc": NVRTCKernelAdapter,
        }
        adapter_cls = source_backend_adapters.get(execution_backend)
        if adapter_cls is None:
            # Handle invalid backend.
            raise ValueError(f"Invalid execution backend: {execution_backend}")

        return adapter_cls(
            params=artifact.params,
            result_idx=out_idx,
            target=target,
            func_or_mod=tilelang_func,
            host_mod=artifact.host_mod,
            device_mod=artifact.device_mod,
            kernel_global_source=artifact.kernel_source,
            verbose=verbose,
            pass_configs=pass_configs,
        )

    def _create_adapter_from_database(
        self,
        params: List[KernelParam],
        result_idx: Union[List[int], int],
        target: Union[str, Target],
        func_or_mod: Union[PrimFunc, tvm.runtime.Module],
        kernel_global_source: str,
        kernel_lib_path: str,
        pass_configs: Optional[Dict[str, Any]] = None,
    ) -> BaseKernelAdapter:
        """
        Reconstructs a kernel adapter from cached (database) kernel artifacts,
        without recompiling the PrimFunc.

        Raises
        ------
        ValueError
            If the backend is "dlpack" (unsupported for cached kernels) or not
            one of the supported backends.
        """
        # NOTE(review): the `target` parameter is immediately shadowed by
        # self.target — preserved from the original; confirm whether the
        # parameter should be honored instead.
        target = self.target
        execution_backend = self.execution_backend

        # Create an adapter based on the specified execution backend.
        if execution_backend == "dlpack":
            raise ValueError("DLPack backend is not supported for TileLang JIT.")

        # Same dedup as _compile_and_create_adapter: all source-level backends
        # expose an identical from_database() signature.
        source_backend_adapters = {
            "ctypes": CtypesKernelAdapter,
            "cython": CythonKernelAdapter,
            "nvrtc": NVRTCKernelAdapter,
        }
        adapter_cls = source_backend_adapters.get(execution_backend)
        if adapter_cls is None:
            # Handle invalid backend.
            raise ValueError(f"Invalid execution backend: {execution_backend}")

        return adapter_cls.from_database(
            params=params,
            result_idx=result_idx,
            target=target,
            func_or_mod=func_or_mod,
            kernel_global_source=kernel_global_source,
            kernel_lib_path=kernel_lib_path,
            pass_configs=pass_configs,
        )

    @classmethod
    def from_tilelang_function(cls, tilelang_func: PrimFunc, **kwargs):
        """
        Alternative constructor to create a TorchFunction directly from a TileLang PrimFunc.

        Parameters
        ----------
        tilelang_func : tvm.tir.PrimFunc
            The TileLang (TVM TIR) function to compile.
        **kwargs : dict
            Additional keyword arguments to pass to the constructor.

        Returns
        -------
        TorchFunction
            An instance of TorchFunction wrapping the compiled function.
        """
        return cls(func=tilelang_func, **kwargs)

    def get_profiler(self,
                     tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
        """
        Creates a profiler to benchmark the compiled runtime module.

        Parameters
        ----------
        tensor_supply_type : TensorSupplyType, optional
            The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).

        Returns
        -------
        Profiler
            A Profiler instance for benchmarking the runtime module.
        """
        return Profiler(self.params, self.out_idx,
                        tensor_supply_type).with_default_adapter(self.adapter)

    def get_kernel_source(self) -> str:
        """
        Returns the source code of the compiled kernel function.

        Returns
        -------
        str
            The source code of the compiled kernel function.
        """
        # Source-level backends keep (possibly post-processed) kernel source
        # on the adapter; otherwise fall back to the compiled artifact.
        if self.execution_backend in {"ctypes", "cython", "nvrtc"}:
            return self.adapter.get_kernel_source()
        return self.artifact.kernel_source

    def get_host_source(self) -> str:
        """
        Returns the source code of the host function.
        """
        return str(self.artifact.host_mod)

    def run_once(self, func: Optional[Callable] = None) -> None:
        """Run the kernel once through a default profiler (for warm-up/smoke checks)."""
        return self.get_profiler().run_once(func)

    def update_tuner_result(self, latency: float, config: Dict[str, Any],
                            ref_latency: float) -> "JITKernel":
        """
        Updates the tuning results for this kernel.

        Parameters
        ----------
        latency : float
            The measured latency of this kernel configuration.
        config : Dict[str, Any]
            The configuration parameters used for this kernel.
        ref_latency : float
            The reference latency to compare against.

        Returns
        -------
        JITKernel
            Self, to allow fluent chaining.
        """
        self.latency = latency
        self.config = config
        self.ref_latency = ref_latency

        return self

    def get_tuner_result(self) -> Dict[str, Any]:
        """
        Gets the tuning results for this kernel.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing:
            - latency: The measured latency of this kernel
            - config: The configuration parameters used
            - ref_latency: The reference latency for comparison

        Raises
        ------
        ValueError
            If the kernel has not been tuned yet.
        """
        if self.latency is None:
            raise ValueError("Tuning results are not available. Please tune the kernel first.")

        return {
            "latency": self.latency,
            "config": self.config,
            "ref_latency": self.ref_latency,
        }

    @property
    def out_idx(self) -> List[int]:
        # The adapter is the single source of truth for which outputs to return.
        return self.adapter.result_idx

    @property
    def params(self) -> List[KernelParam]:
        # Kernels restored from the database have no artifact; use the adapter.
        return self.artifact.params if self.artifact else self.adapter.params

    @property
    def kernel_source(self) -> str:
        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source

    @property
    def host_source(self) -> str:
        return str(self.artifact.host_mod) if self.artifact else ""

    def export_library(self, kernel_file: str) -> None:
        """
        Exports the compiled kernel function to a shared library file.

        Parameters
        ----------
        kernel_file : str
            The path to the shared library file to create.

        Raises
        ------
        AssertionError
            If no TVM runtime module is available. Per
            _compile_and_create_adapter, a runtime module is only produced for
            the "dlpack" backend.
        """
        # BUGFIX: the original referenced `self.rt_module`, which is never
        # defined on this class and would raise AttributeError. The runtime
        # module lives on the compiled artifact as `rt_mod`.
        assert self.artifact is not None and self.artifact.rt_mod is not None, (
            "export_library requires a compiled TVM runtime module; "
            "it is only available for the 'dlpack' execution backend.")
        self.artifact.rt_mod.export_library(kernel_file)