Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phase Out LLVM Dependency by Making It Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to use the new `get_profiler` method for consistent performance measurement.
- Adjusted assertions and benchmarking in the examples to match the new profiling structure, keeping the performance evaluations correct and clear (see the sketch below).
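
A minimal before/after sketch of the new flow (`matmul` and `ref_program` are illustrative names, not from this commit; profiler method names follow the profiler API referenced above):

```python
import tilelang

# Previously: rt_mod, params = tilelang.lower(matmul, ...)
# Now: compile once and get a JITKernel back.
kernel = tilelang.compile(matmul, out_idx=[2])  # `matmul`: an illustrative TileLang PrimFunc

# Profiling is routed through the kernel's profiler object.
profiler = kernel.get_profiler()
latency = profiler.do_bench()  # benchmark the compiled kernel
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)  # `ref_program`: illustrative reference impl
```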

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` so compiled outputs are managed in one place.
- Updated kernel source references to use `artifact.kernel_source` for consistent handling across execution backends (see the sketch below).
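
For instance (a sketch; `kernel` stands for any compiled `JITKernel`):

```python
# The compiled artifact now travels with the kernel object.
artifact = kernel.artifact
print(artifact.kernel_source)      # raw generated device source
print(kernel.get_kernel_source())  # backends resolve the source via the artifact or adapter
```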

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update its registration. Add the @tilelang.testing.requires_llvm decorator to multiple tests that require LLVM.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated the `host_codegen` and `device_codegen` functions to include the new transformations and registered `tilelang_hip_without_compile`. Refactored the JIT kernel adapters to accept host and device modules, improving integration and flexibility; see the sketch below.
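
Roughly, an adapter can now be constructed from the split modules instead of re-deriving them, as in this sketch mirroring the new call sites (`artifact` and `tilelang_func` are illustrative placeholders):

```python
from tilelang.jit.adapter import CythonKernelAdapter

adapter = CythonKernelAdapter(
    params=artifact.params,
    result_idx=[2],
    target="cuda",
    func_or_mod=tilelang_func,
    host_mod=artifact.host_mod,      # host-side IRModule, parsed instead of re-lowered
    device_mod=artifact.device_mod,  # device-side IRModule holding the kernels
    kernel_global_source=artifact.kernel_source,
)
```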

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added the TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated the CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear the cache when TILELANG_CLEAR_CACHE is set to a true value (see the sketch below).
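
For example (a sketch; treating "1" as a true value is an assumption based on the "0" default shown below):

```python
import os

# Must be set before tilelang is imported: the cache is
# initialized (and optionally cleared) at import time.
os.environ["TILELANG_CLEAR_CACHE"] = "1"

import tilelang  # noqa: E402  # the cache under ~/.tilelang/cache is cleared here
```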

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management (see the sketch after this list).
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.
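
A sketch of the resulting cache-hit path (argument values are illustrative; the real call site lives in `kernel_cache.py`):

```python
kernel = JITKernel.from_database(
    func=tilelang_func,                  # original PrimFunc, used to rebuild metadata
    kernel_global_source=cached_source,  # kernel source text restored from the cache
    kernel_lib_path=cached_lib_path,     # path to the precompiled shared library
    params=cached_params,                # List[KernelParam] restored from disk
    target="cuda",
    target_host=None,
    out_idx=[2],
    execution_backend="cython",
)
out = kernel(a, b)  # runs the cached library without recompiling
```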

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
"""The profiler and convert to torch utils"""
from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Union, Optional
import torch
from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var
......@@ -82,3 +82,17 @@ class KernelParam:
bool: True if parameter is a float8 type, False otherwise
"""
return str(self.dtype).removeprefix("torch.").startswith("float8")
@dataclass
class CompiledArtifact:
"""
Represents a compiled kernel artifact containing both host and device code.
Stores all necessary components for kernel execution in the TVM runtime.
"""
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel
kernel_source: str # Raw source code of the generated kernel
rt_mod: Optional[
tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized
......@@ -43,6 +43,9 @@ TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
TILELANG_CACHE_DIR: str = os.environ.get("TILELANG_CACHE_DIR",
os.path.expanduser("~/.tilelang/cache"))
# Auto-clear cache if environment variable is set
TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0")
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
", which may lead to compilation bugs when utilizing the tilelang backend.")
......@@ -84,6 +87,13 @@ else:
os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
# pip install build library path
lib_path = os.path.join(TILELANG_PACKAGE_PATH, "lib")
existing_path = os.environ.get("TVM_LIBRARY_PATH")
if existing_path:
os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}"
else:
os.environ["TVM_LIBRARY_PATH"] = lib_path
TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
if os.environ.get("TL_CUTLASS_PATH", None) is None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .utils import (
mma_store_index_map, # noqa: F401
get_ldmatrix_offset, # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import DataType
from tvm.runtime import convert
import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.language as T
from typing import Tuple
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
from tvm import arith, DataType
import tilelang.language as T
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang.language as T
from typing import Union, Tuple, Optional, Literal, Callable
from tilelang.common import TransformKind
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import DataType
from typing import Literal
from .mma_layout import (
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from abc import ABC, abstractmethod
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .adapter import CtypesKernelAdapter # noqa: F401
......@@ -45,13 +45,14 @@ class CtypesKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
host_mod: Optional[tvm.IRModule] = None,
device_mod: Optional[tvm.IRModule] = None,
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
rt_mod: Runtime module
params: List of tensor types for inputs/outputs
result_idx: Indices of output tensors
target: Target platform (e.g., 'cuda')
......@@ -90,6 +91,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
......@@ -105,13 +108,15 @@ class CtypesKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-kernel_global_source: Optional[str] = None,
+kernel_global_source: str,
+kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = kernel_global_source
adapter.wrapped_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......@@ -136,16 +141,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
-adapter.wrapper = TLWrapper(adapter.target)
adapter.lib_generator = LibraryGenerator(adapter.target)
-adapter.wrapper.assign_optimized_module(adapter.ir_module)
-adapter.wrapper.assign_pass_configs(pass_configs)
-adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
-adapter.lib_generator.update_lib_code(adapter.wrapped_source)
-adapter.lib_generator.compile_lib()
-adapter.lib = adapter.lib_generator.load_lib()
+adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.init()
adapter._post_init()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .adapter import CythonKernelAdapter # noqa: F401
......@@ -156,13 +156,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-kernel_global_source: str,
+host_mod: Optional[tvm.IRModule] = None,
+device_mod: Optional[tvm.IRModule] = None,
+kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
rt_mod: Runtime module
params: List of tensor types for inputs/outputs
result_idx: Indices of output tensors
target: Target platform (e.g., 'cuda')
......@@ -191,6 +192,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapper.assign_host_module(host_mod)
self.wrapper.assign_device_module(device_mod)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
......@@ -212,17 +215,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
@classmethod
def from_database(cls,
-rt_mod_src: str,
params: List[TensorType],
result_idx: List[int],
-target,
+target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+kernel_global_source: str,
+kernel_lib_path: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
-adapter.kernel_global_source = rt_mod_src
+adapter.kernel_global_source = kernel_global_source
+adapter.wrapped_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......@@ -238,16 +243,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter.buffer_device_map = adapter._process_buffer_device()
adapter.verbose = verbose
-adapter.wrapper = TLWrapper(adapter.target)
adapter.lib_generator = LibraryGenerator(adapter.target)
-adapter.wrapper.assign_optimized_module(adapter.ir_module)
-adapter.wrapper.assign_pass_configs(pass_configs)
-adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
-adapter.lib_generator.update_lib_code(adapter.wrapped_source)
-adapter.lib_generator.compile_lib()
-adapter.lib = adapter.lib_generator.load_lib()
+adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
try:
adapter.lib.init()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from .utils import is_cuda_target, is_hip_target, is_cpu_target
from tilelang import tvm as tvm
......@@ -27,8 +25,10 @@ class LibraryGenerator(object):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
-def load_lib(self):
-return ctypes.CDLL(self.libpath)
+def load_lib(self, lib_path: Optional[str] = None):
+if lib_path is None:
+lib_path = self.libpath
+return ctypes.CDLL(lib_path)
def compile_lib(self, timeout: float = None, with_tl: bool = True):
target = self.target
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
from typing import Union, Optional, Literal
from tilelang import tvm as tvm
......
......@@ -77,16 +77,21 @@ class TLCUDASourceWrapper(object):
backend = "tl"
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.pass_configs = pass_configs
self.device_mod = device_mod
self.host_mod = host_mod
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
......@@ -250,19 +255,20 @@ class TLCUDASourceWrapper(object):
return tma_descripter_init
def parse_source_information(self):
-with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
-device_mod, host_mod = get_annotated_mod(self.mod, self.target)
-assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
-assert (len(host_mod.functions) == 1), "Only support one function in host module."
-self.device_mod = device_mod
-self.host_mod = host_mod
+if self.device_mod is None or self.host_mod is None:
+with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
+device_mod, host_mod = get_annotated_mod(self.mod, self.target)
+self.device_mod = device_mod
+self.host_mod = host_mod
+assert (len(self.device_mod.functions)
+>= 1), "Device module should have at least one function."
+assert (len(self.host_mod.functions) == 1), "Only support one function in host module."
block_info_map = {}
grid_info_map = {}
dynamic_smem_buf_map = {}
function_names = []
-for g_var, func in device_mod.functions.items():
+for g_var, func in self.device_mod.functions.items():
# Default block and grid configurations
block_info = [1, 1, 1]
grid_info = [1, 1, 1]
......@@ -291,7 +297,7 @@ class TLCUDASourceWrapper(object):
self.dynamic_smem_buf = dynamic_smem_buf_map
function_names_index = {}
-for _, func in host_mod.functions.items():
+for _, func in self.host_mod.functions.items():
if "tma_descriptor_args" in func.attrs:
self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
host_code = str(func)
......@@ -369,13 +375,18 @@ class TLCUDASourceWrapper(object):
class TLHIPSourceWrapper(TLCUDASourceWrapper):
"""
A wrapper class for the TileLang HIP backend.
"""
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
-super().__init__(scheduled_ir_module, source, target, pass_configs)
+super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_hip_init_func(self):
# Initialize an empty string for the CUDA function call
......@@ -418,16 +429,22 @@ class TLCPUSourceWrapper(object):
""")
backend = "tl"
backend = "tl"
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
device_mod: Optional[IRModule] = None,
host_mod: Optional[IRModule] = None,
pass_configs: Optional[Dict[str, Any]] = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.device_mod = device_mod
self.host_mod = host_mod
self.pass_configs = pass_configs
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
......@@ -563,6 +580,14 @@ class TLCPUSourceWrapper(object):
class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
"""
device_mod: Optional[IRModule] = None
host_mod: Optional[IRModule] = None
pass_configs: Optional[Dict[str, Any]] = None
target: Optional[Target] = None
lib: Optional[object] = None
def __init__(self, target: Target):
super().__init__()
......@@ -577,6 +602,12 @@ class TLWrapper(BaseWrapper):
def assign_pass_configs(self, pass_configs: Dict[str, Any]):
self.pass_configs = pass_configs
def assign_host_module(self, host_mod: IRModule):
self.host_mod = host_mod
def assign_device_module(self, device_mod: IRModule):
self.device_mod = device_mod
# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str):
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
......@@ -588,5 +619,11 @@ class TLWrapper(BaseWrapper):
wrapper_class = TLCPUSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
-wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target, self.pass_configs)
+wrapper = wrapper_class(
+scheduled_ir_module=self.scheduled_ir_module,
+source=c_source,
+target=self.target,
+device_mod=self.device_mod,
+host_mod=self.host_mod,
+pass_configs=self.pass_configs)
return wrapper.lib_code
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
......
......@@ -4,10 +4,15 @@ import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
-from tilelang.jit.adapter import TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter
+from tilelang.jit.adapter import (
+TorchDLPackKernelAdapter,
+BaseKernelAdapter,
+CtypesKernelAdapter,
+CythonKernelAdapter,
+)
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.profiler import Profiler, TensorSupplyType
-from tilelang.engine.param import KernelParam
+from tilelang.engine.param import KernelParam, CompiledArtifact
class JITKernel(object):
......@@ -16,15 +21,15 @@ class JITKernel(object):
Attributes
----------
-rt_mod : tvm.runtime.Module
-The runtime module compiled by TVM.
-params : List[KernelParam]
-Parameters for the compiled runtime module (e.g., weights or constants).
+artifact : CompiledArtifact
+The compiled artifact containing the runtime module and parameters.
adapter : BaseKernelAdapter
The adapter for the compiled function.
torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function.
"""
-rt_mod: tvm.runtime.Module = None
-params: List[KernelParam] = None
+artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None
torch_function: Callable = None
......@@ -37,8 +42,7 @@ class JITKernel(object):
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
-rt_module_src: Optional[str] = None,
-rt_params: dict = None,
+from_database: bool = False,
):
"""
Initializes a TorchFunction instance.
......@@ -59,9 +63,11 @@ class JITKernel(object):
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
from_database : bool, optional
Whether to create a TorchFunction from a database.
"""
self.out_idx = out_idx
self.execution_backend = execution_backend
......@@ -73,42 +79,6 @@ class JITKernel(object):
pass_configs = {}
self.pass_configs = pass_configs
-if rt_module_src and rt_params:
-self.rt_mod = None
-self.params = rt_params
-adapter = None
-# Create an adapter based on the specified execution backend.
-if execution_backend == "dlpack":
-# assert dlpack not supported
-raise ValueError(f"Invalid execution backend: {execution_backend}")
-elif execution_backend == "ctypes":
-adapter = CtypesKernelAdapter.from_database(
-params=self.params,
-result_idx=out_idx,
-target=target,
-func_or_mod=func,
-kernel_global_source=rt_module_src,
-verbose=verbose,
-pass_configs=pass_configs,
-)
-elif execution_backend == "cython":
-adapter = CythonKernelAdapter.from_database(
-rt_mod_src=rt_module_src,
-params=self.params,
-result_idx=out_idx,
-target=target,
-func_or_mod=func,
-verbose=verbose,
-pass_configs=pass_configs,
-)
-else:
-# Handle invalid backend.
-raise ValueError(f"Invalid execution backend: {execution_backend}")
-self.adapter = adapter
-self.torch_function = adapter.func
-return
# If the target is specified as a string, validate it and convert it to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
......@@ -118,12 +88,20 @@ class JITKernel(object):
target = Target(target)
# Validate the execution backend.
assert execution_backend in ["dlpack", "ctypes",
"cython"], f"Invalid execution backend. {execution_backend}"
assert execution_backend in [
"dlpack",
"ctypes",
"cython",
], f"Invalid execution backend. {execution_backend}"
if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler
-assert get_cplus_compiler(
-) is not None, "Cython backend requires a C++ compiler, please install or use other backends."
+assert (
+get_cplus_compiler() is not None
+), "Cython backend requires a C++ compiler, please install or use other backends."
if from_database:
return
# Compile the TileLang function and create a kernel adapter for execution.
adapter = self._compile_and_create_adapter(func)
......@@ -132,6 +110,43 @@ class JITKernel(object):
self.adapter = adapter
self.torch_function = adapter.func
@classmethod
def from_database(
cls,
func: PrimFunc,
kernel_global_source: str,
kernel_lib_path: str,
params: List[KernelParam],
target: Union[str, Target],
target_host: Union[str, Target],
out_idx: Union[List[int], int],
execution_backend: Literal["dlpack", "ctypes", "cython"],
pass_configs: Optional[Dict[str, Any]] = None,
):
"""
Alternative constructor to create a TorchFunction directly from a database.
"""
instance = cls(
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
pass_configs=pass_configs,
from_database=True,
)
instance.adapter = instance._create_adapter_from_database(
func_or_mod=func,
params=params,
result_idx=out_idx,
target=target,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
)
instance.torch_function = instance.adapter.func
return instance
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""
Invokes the compiled function with the given arguments.
......@@ -173,37 +188,39 @@ class JITKernel(object):
# Compile the function with TVM, optimizing with shared memory lowering.
with tvm.transform.PassContext(opt_level=3, config=pass_configs):
-rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)
+artifact = tilelang.lower(tilelang_func, target=target, target_host=target_host)
# Store the runtime module and parameters for later use.
-self.rt_mod = rt_mod
-self.params = params
+self.artifact = artifact
# Create an adapter based on the specified execution backend.
if execution_backend == "dlpack":
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
-adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
+# But we need to ensure that the runtime is enabled and the runtime module is not None.
+assert tvm.runtime.enabled("llvm"), "DLPack backend requires LLVM runtime."
+assert (artifact.rt_mod is not None), "DLPack backend requires a runtime module."
+adapter = TorchDLPackKernelAdapter(
+artifact.rt_mod, params=artifact.params, result_idx=out_idx)
elif execution_backend == "ctypes":
-# TODO(Lei): global source extraction can be simplified
-kernel_global_source = rt_mod.imported_modules[0].get_source()
adapter = CtypesKernelAdapter(
-params=params,
+params=artifact.params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
-kernel_global_source=kernel_global_source,
+host_mod=artifact.host_mod,
+device_mod=artifact.device_mod,
+kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "cython":
-# TODO(Lei): global source extraction can be simplified
-kernel_global_source = rt_mod.imported_modules[0].get_source()
adapter = CythonKernelAdapter(
-params=params,
+params=artifact.params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
-kernel_global_source=kernel_global_source,
+host_mod=artifact.host_mod,
+device_mod=artifact.device_mod,
+kernel_global_source=artifact.kernel_source,
verbose=verbose,
pass_configs=pass_configs,
)
......@@ -213,6 +230,45 @@ class JITKernel(object):
return adapter
def _create_adapter_from_database(
self,
params: List[KernelParam],
result_idx: Union[List[int], int],
target: Union[str, Target],
func_or_mod: Union[PrimFunc, tvm.runtime.Module],
kernel_global_source: str,
kernel_lib_path: str,
) -> BaseKernelAdapter:
target = self.target
execution_backend = self.execution_backend
# Create an adapter based on the specified execution backend.
if execution_backend == "dlpack":
raise ValueError("DLPack backend is not supported for TileLang JIT.")
elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter.from_database(
params=params,
result_idx=result_idx,
target=target,
func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter.from_database(
params=params,
result_idx=result_idx,
target=target,
func_or_mod=func_or_mod,
kernel_global_source=kernel_global_source,
kernel_lib_path=kernel_lib_path,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
return adapter
@classmethod
def from_tilelang_function(cls, tilelang_func: PrimFunc, **kwargs):
"""
......@@ -261,17 +317,29 @@ class JITKernel(object):
"""
if self.execution_backend in {"ctypes", "cython"}:
return self.adapter.get_kernel_source()
-return self.rt_mod.imported_modules[0].get_source()
+return self.artifact.kernel_source
def get_host_source(self) -> str:
"""
Returns the source code of the host function.
"""
-return self.rt_mod.get_source()
+return str(self.artifact.host_mod)
def run_once(self, func: Optional[Callable] = None) -> None:
return self.get_profiler().run_once(func)
@property
def params(self) -> List[KernelParam]:
return self.artifact.params if self.artifact else self.adapter.params
@property
def kernel_source(self) -> str:
return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
@property
def host_source(self) -> str:
return str(self.artifact.host_mod) if self.artifact else ""
def export_library(self, kernel_file: str) -> None:
"""
Exports the compiled kernel function to a shared library file.
......