"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c425d1a7232ac843dd6d9cb65e3cd56ecc47a85d"
Commit f2e99180 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phase Out LLVM Dependency by Making It Optional (#247)

* remove llvm build

* [Refactor] Update kernel compilation and profiling in examples

- Replaced `tilelang.lower` with `tilelang.compile` in multiple example scripts to streamline kernel compilation.
- Updated profiling calls to use the new `get_profiler` method, making performance measurement more consistent.
- Adjusted assertions and benchmarking methods to align with the new profiling structure across various examples, ensuring correctness and clarity in performance evaluations.
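The profiling flow these bullets describe (compile once, then fetch a profiler from the kernel handle) can be sketched with stdlib timing; `CompiledKernel`, `Profiler`, and `do_bench` below are illustrative stand-ins, not the tilelang API:

```python
import time
from typing import Callable


class Profiler:
    """Illustrative profiler object returned by a kernel's get_profiler()."""

    def __init__(self, fn: Callable[[], object]):
        self.fn = fn

    def do_bench(self, warmup: int = 2, rep: int = 10) -> float:
        """Average latency of fn in milliseconds over `rep` timed runs."""
        for _ in range(warmup):
            self.fn()  # warm-up runs are not timed
        start = time.perf_counter()
        for _ in range(rep):
            self.fn()
        return (time.perf_counter() - start) / rep * 1e3


class CompiledKernel:
    """Illustrative compiled-kernel handle exposing get_profiler()."""

    def __init__(self, fn: Callable[[], object]):
        self.fn = fn

    def get_profiler(self) -> Profiler:
        return Profiler(self.fn)


latency_ms = CompiledKernel(lambda: sum(range(10_000))).get_profiler().do_bench()
```

Routing all measurement through one profiler object is what keeps benchmarking consistent across examples.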

* lint fix

* License Update

* [Refactor] Improve code formatting and documentation in CUDA header and HIP runtime files

- Adjusted formatting in `cuda.h` for better readability, including alignment of comments and struct fields.
- Cleaned up whitespace and improved comment clarity in `rt_mod_hip.cc` to enhance code maintainability.

* [Refactor] Enhance formatting and clarity in CUDA header and HIP runtime files

- Improved comment alignment and readability in `cuda.h`.
- Cleaned up whitespace and formatting in `rt_mod_hip.cc` to enhance maintainability.

* lint fix

* lint fix

* lint fix

* lint fix

* fix

* License update

* [Enhancement] Update JITKernel to use artifact for kernel source

- Assigned the generated artifact to `self.artifact` for better management.
- Updated kernel source references to use `artifact.kernel_source` for consistency in execution backend handling.
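The `artifact` object referenced here is at heart a data container; the `CompiledArtifact` dataclass added later in this commit has this shape (TVM module types replaced by `object` placeholders in this sketch):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CompiledArtifact:
    """Container for everything an execution backend needs for one kernel."""
    host_mod: object                 # host-side IR module (placeholder type)
    device_mod: object               # device-side IR module (placeholder type)
    params: List[object]             # kernel parameters (tensors/scalars)
    kernel_source: str               # generated device source code
    rt_mod: Optional[object] = None  # runtime module, may be lazily initialized


artifact = CompiledArtifact(
    host_mod=None, device_mod=None, params=[], kernel_source="// kernel source")
```

Backends then read `artifact.kernel_source` instead of re-deriving the source themselves.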

* lint fix

* Add @tilelang.testing.requires_llvm decorator to vectorization tests

* Enhance setup.py and env.py for library management

- Added functionality to remove original files after copying in CMakeBuild.
- Updated TVM_LIBRARY_PATH in env.py to include the PyPI build library path for better integration.

* Refactor TVM_LIBRARY_PATH assignment for improved readability in env.py

* Refactor CMakeBuild file handling in setup.py

- Added a check to ensure the target library directory exists before copying .so files.
- Improved the logic for creating the target directory and copying files to enhance robustness.
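The check-then-copy logic these bullets describe can be sketched as follows (function and path names hypothetical):

```python
import os
import shutil
import tempfile
from typing import List


def copy_libs(build_dir: str, target_dir: str) -> List[str]:
    """Copy every .so file from build_dir into target_dir, creating it if needed."""
    os.makedirs(target_dir, exist_ok=True)  # ensure the target directory exists
    copied = []
    for name in sorted(os.listdir(build_dir)):
        if name.endswith(".so"):
            shutil.copy2(os.path.join(build_dir, name), os.path.join(target_dir, name))
            copied.append(name)
    return copied


# Demo: one shared library and one unrelated file in a scratch build dir.
build = tempfile.mkdtemp()
open(os.path.join(build, "libdemo.so"), "w").close()
open(os.path.join(build, "notes.txt"), "w").close()
copied = copy_libs(build, os.path.join(build, "pkg", "lib"))
```

`os.makedirs(..., exist_ok=True)` is what makes the copy robust when the target directory does not exist yet.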

* bugfix

* Rename BuildTLDebug to BuildTileLangCUDAWithoutCompile and update registration. Add @tilelang.testing.requires_llvm decorator to multiple tests for LLVM requirement.

* lint fix

* Enhance TileLang code generation by adding support for device code generation without compilation. Updated `host_codegen` and `device_codegen` functions to include new transformations and registration for `tilelang_hip_without_compile`. Refactored JIT kernel adapters to accommodate host and device modules, improving overall integration and flexibility.

* lint fix

* Add support for C target in device code generation

- Updated `device_codegen_without_compile` to include handling for the C target by registering the `tilelang_cpp` function.

* [Enhancement] Implement auto-clear cache feature based on environment variable

* Added TILELANG_CLEAR_CACHE environment variable to control cache clearing.
* Updated CI workflow to set TILELANG_CLEAR_CACHE during testing.
* Modified cache initialization to clear cache if TILELANG_CLEAR_CACHE is set to true.
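The bullets above can be sketched as a small helper (helper name hypothetical; the commit itself reads the variable in `env.py` and acts on it during cache initialization):

```python
import os
import shutil
import tempfile


def maybe_clear_cache(cache_dir: str) -> bool:
    """Wipe and re-create cache_dir when TILELANG_CLEAR_CACHE is truthy."""
    flag = os.environ.get("TILELANG_CLEAR_CACHE", "0").lower()
    if flag in ("1", "true", "yes", "on"):
        shutil.rmtree(cache_dir, ignore_errors=True)
        os.makedirs(cache_dir, exist_ok=True)  # leave an empty cache behind
        return True
    return False


# Demo: populate a scratch cache, then clear it via the environment variable.
cache = tempfile.mkdtemp()
open(os.path.join(cache, "kernel.cu"), "w").close()
os.environ["TILELANG_CLEAR_CACHE"] = "1"
cleared = maybe_clear_cache(cache)
```

Setting the variable in CI, as the workflow change does, guarantees each test run starts from a cold cache.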

* [Refactor] Update kernel invocation and import paths in tests and cache

* Changed kernel invocation in `test_tilelang_kernel_dequantize_gemm.py` to return the result.
* Updated import statements in `test_tilelang_kernel_int4_gemm_mma.py` to use `bitblas` instead of `tilelang`.
* Refactored paths for artifact and parameters in `kernel_cache.py` for better maintainability.

* [Refactor] Clean up whitespace and improve code formatting in kernel_cache.py

* Removed unnecessary blank lines and adjusted spacing for better readability in the KernelCache class.
* Enhanced overall code formatting to align with project standards.

* [Enhancement] Add bfloat16 test case and improve kernel caching logic

* Introduced a new test case for bfloat16 matrix multiplication in `test_tilelang_kernel_gemm_mma_intrinsic.py`.
* Updated `KernelCache` to handle multiple kernel source files and improve error handling during saving and loading.
* Refactored `JITKernel` to support instantiation from a database, enhancing flexibility in kernel management.
* Adjusted `CtypesKernelAdapter` and `CythonKernelAdapter` to utilize the new kernel loading mechanism from the database.
* Improved code formatting and readability across several files.
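The database-instantiation path added here relies on the alternative-constructor pattern: `cls.__new__` skips `__init__`, so cached state can be restored without recompiling. A generic sketch of that pattern:

```python
class Kernel:
    """Toy kernel with a normal compile path and a cache-restore path."""

    def __init__(self, source: str):
        # Normal construction: "compile" the source (stand-in for real work).
        self.source = source
        self.compiled = f"compiled({source})"

    @classmethod
    def from_database(cls, source: str, compiled: str) -> "Kernel":
        # Bypass __init__ so no compilation happens; restore cached fields.
        instance = cls.__new__(cls)
        instance.source = source
        instance.compiled = compiled
        return instance


restored = Kernel.from_database("gemm", "compiled(gemm)")
```

The restored object is indistinguishable from a freshly compiled one, which is what lets the adapters load a prebuilt library instead of rebuilding it.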

* lint fix

* Update bfloat16 matrix multiplication test case to use larger dimensions for improved coverage
parent 43bd9d3e
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Union, Optional
import torch import torch
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var from tvm.tir import Buffer, IntImm, Var
...@@ -82,3 +82,17 @@ class KernelParam: ...@@ -82,3 +82,17 @@ class KernelParam:
bool: True if parameter is a float8 type, False otherwise bool: True if parameter is a float8 type, False otherwise
""" """
return str(self.dtype).removeprefix("torch.").startswith("float8") return str(self.dtype).removeprefix("torch.").startswith("float8")
@dataclass
class CompiledArtifact:
"""
Represents a compiled kernel artifact containing both host and device code.
Stores all necessary components for kernel execution in the TVM runtime.
"""
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel
kernel_source: str # Raw source code of the generated kernel
rt_mod: Optional[
tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized
@@ -43,6 +43,9 @@ TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
 TILELANG_CACHE_DIR: str = os.environ.get("TILELANG_CACHE_DIR",
                                          os.path.expanduser("~/.tilelang/cache"))
+
+# Auto-clear cache if environment variable is set
+TILELANG_CLEAR_CACHE = os.environ.get("TILELANG_CLEAR_CACHE", "0")
 # SETUP ENVIRONMENT VARIABLES
 CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
 ", which may lead to compilation bugs when utilize tilelang backend."
@@ -84,6 +87,13 @@ else:
         os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
     else:
         logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
+    # pip install build library path
+    lib_path = os.path.join(TILELANG_PACKAGE_PATH, "lib")
+    existing_path = os.environ.get("TVM_LIBRARY_PATH")
+    if existing_path:
+        os.environ["TVM_LIBRARY_PATH"] = f"{existing_path}:{lib_path}"
+    else:
+        os.environ["TVM_LIBRARY_PATH"] = lib_path
 TVM_LIBRARY_PATH = os.environ.get("TVM_LIBRARY_PATH", None)
 if os.environ.get("TL_CUTLASS_PATH", None) is None:
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from .utils import (
     mma_store_index_map,  # noqa: F401
     get_ldmatrix_offset,  # noqa: F401
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tvm import DataType
 from tvm.runtime import convert
 import tilelang.language as T
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tilelang import tvm as tvm
 import tilelang.language as T
 from typing import Tuple
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from typing import Union
 from tvm import arith, DataType
 import tilelang.language as T
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import tilelang.language as T
 from typing import Union, Tuple, Optional, Literal, Callable
 from tilelang.common import TransformKind
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tvm import DataType
 from typing import Literal
 from .mma_layout import (
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from .base import BaseKernelAdapter  # noqa: F401
 from .dlpack import TorchDLPackKernelAdapter  # noqa: F401
 from .ctypes import CtypesKernelAdapter  # noqa: F401
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 """The profiler and convert to torch utils"""
 from abc import ABC, abstractmethod
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from .adapter import CtypesKernelAdapter  # noqa: F401
...
@@ -45,13 +45,14 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                  result_idx: List[int],
                  target: str,
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+                 host_mod: Optional[tvm.IRModule] = None,
+                 device_mod: Optional[tvm.IRModule] = None,
                  kernel_global_source: Optional[str] = None,
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None):
         """Initialize the adapter with the given TIR function or module.
         Args:
-            rt_mod: Runtime module
             params: List of tensor types for inputs/outputs
             result_idx: Indices of output tensors
             target: Target platform (e.g., 'cuda')
@@ -90,6 +91,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         self.wrapper.assign_optimized_module(self.ir_module)
         self.wrapper.assign_pass_configs(pass_configs)
+        self.wrapper.assign_host_module(host_mod)
+        self.wrapper.assign_device_module(device_mod)
         self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
         self.lib_generator.update_lib_code(self.wrapped_source)
@@ -105,13 +108,15 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                       result_idx: List[int],
                       target: str,
                       func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-                      kernel_global_source: Optional[str] = None,
+                      kernel_global_source: str,
+                      kernel_lib_path: str,
                       verbose: bool = False,
                       pass_configs: Optional[Dict[str, Any]] = None):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
         adapter.kernel_global_source = kernel_global_source
+        adapter.wrapped_source = kernel_global_source
         if isinstance(func_or_mod, tir.PrimFunc):
             adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
@@ -136,16 +141,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         adapter.target = Target.canon_target(determine_target(target))
         adapter.verbose = verbose
-        adapter.wrapper = TLWrapper(adapter.target)
         adapter.lib_generator = LibraryGenerator(adapter.target)
+        adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
-        adapter.wrapper.assign_optimized_module(adapter.ir_module)
-        adapter.wrapper.assign_pass_configs(pass_configs)
-        adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
-        adapter.lib_generator.update_lib_code(adapter.wrapped_source)
-        adapter.lib_generator.compile_lib()
-        adapter.lib = adapter.lib_generator.load_lib()
         adapter.lib.init()
         adapter._post_init()
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from .adapter import CythonKernelAdapter  # noqa: F401
...
@@ -156,13 +156,14 @@ class CythonKernelAdapter(BaseKernelAdapter):
                  result_idx: List[int],
                  target: Union[str, Target],
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-                 kernel_global_source: str,
+                 host_mod: Optional[tvm.IRModule] = None,
+                 device_mod: Optional[tvm.IRModule] = None,
+                 kernel_global_source: Optional[str] = None,
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None):
         """Initialize the adapter with the given TIR function or module.
         Args:
-            rt_mod: Runtime module
             params: List of tensor types for inputs/outputs
             result_idx: Indices of output tensors
             target: Target platform (e.g., 'cuda')
@@ -191,6 +192,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.wrapper.assign_optimized_module(self.ir_module)
         self.wrapper.assign_pass_configs(pass_configs)
+        self.wrapper.assign_host_module(host_mod)
+        self.wrapper.assign_device_module(device_mod)
         self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
         self.lib_generator.update_lib_code(self.wrapped_source)
@@ -212,17 +215,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
     @classmethod
     def from_database(cls,
-                      rt_mod_src: str,
                       params: List[TensorType],
                       result_idx: List[int],
-                      target,
+                      target: str,
                       func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+                      kernel_global_source: str,
+                      kernel_lib_path: str,
                       verbose: bool = False,
                       pass_configs: Optional[Dict[str, Any]] = None):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
-        adapter.kernel_global_source = rt_mod_src
+        adapter.kernel_global_source = kernel_global_source
+        adapter.wrapped_source = kernel_global_source
         if isinstance(func_or_mod, tir.PrimFunc):
             adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
@@ -238,16 +243,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.buffer_device_map = adapter._process_buffer_device()
         adapter.verbose = verbose
-        adapter.wrapper = TLWrapper(adapter.target)
         adapter.lib_generator = LibraryGenerator(adapter.target)
+        adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
-        adapter.wrapper.assign_optimized_module(adapter.ir_module)
-        adapter.wrapper.assign_pass_configs(pass_configs)
-        adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
-        adapter.lib_generator.update_lib_code(adapter.wrapped_source)
-        adapter.lib_generator.compile_lib()
-        adapter.lib = adapter.lib_generator.load_lib()
         try:
             adapter.lib.init()
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 """The profiler and convert to torch utils"""
 import torch
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
+from typing import Optional
 from .utils import is_cuda_target, is_hip_target, is_cpu_target
 from tilelang import tvm as tvm
@@ -27,8 +25,10 @@ class LibraryGenerator(object):
         self.lib_code = lib_code
     # Assume currently we only support CUDA compilation
-    def load_lib(self):
-        return ctypes.CDLL(self.libpath)
+    def load_lib(self, lib_path: Optional[str] = None):
+        if lib_path is None:
+            lib_path = self.libpath
+        return ctypes.CDLL(lib_path)
     def compile_lib(self, timeout: float = None, with_tl: bool = True):
         target = self.target
......
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 import re
 from typing import Union, Optional, Literal
 from tilelang import tvm as tvm
......
@@ -77,16 +77,21 @@ class TLCUDASourceWrapper(object):
     backend = "tl"
     device_mod: Optional[IRModule] = None
     host_mod: Optional[IRModule] = None
+    pass_configs: Optional[Dict[str, Any]] = None

     def __init__(self,
                  scheduled_ir_module: IRModule,
                  source: str,
                  target: Target,
+                 device_mod: Optional[IRModule] = None,
+                 host_mod: Optional[IRModule] = None,
                  pass_configs: Optional[Dict[str, Any]] = None):
         self.mod = scheduled_ir_module
         self.target = target
         self.source = source
         self.pass_configs = pass_configs
+        self.device_mod = device_mod
+        self.host_mod = host_mod
         self.function_names: Optional[str] = None
         self.dynamic_smem_buf: Optional[int] = None
         self.block_info: Union[List[int], Dict] = [1, 1, 1]
@@ -250,19 +255,20 @@ class TLCUDASourceWrapper(object):
         return tma_descripter_init

     def parse_source_information(self):
-        with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
-            device_mod, host_mod = get_annotated_mod(self.mod, self.target)
-        assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
-        assert (len(host_mod.functions) == 1), "Only support one function in host module."
-        self.device_mod = device_mod
-        self.host_mod = host_mod
+        if self.device_mod is None or self.host_mod is None:
+            with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
+                device_mod, host_mod = get_annotated_mod(self.mod, self.target)
+            self.device_mod = device_mod
+            self.host_mod = host_mod
+        assert (len(self.device_mod.functions)
+                >= 1), "Device module should have at least one function."
+        assert (len(self.host_mod.functions) == 1), "Only support one function in host module."

         block_info_map = {}
         grid_info_map = {}
         dynamic_smem_buf_map = {}
         function_names = []
-        for g_var, func in device_mod.functions.items():
+        for g_var, func in self.device_mod.functions.items():
             # Default block and grid configurations
             block_info = [1, 1, 1]
             grid_info = [1, 1, 1]
@@ -291,7 +297,7 @@ class TLCUDASourceWrapper(object):
         self.dynamic_smem_buf = dynamic_smem_buf_map

         function_names_index = {}
-        for _, func in host_mod.functions.items():
+        for _, func in self.host_mod.functions.items():
             if "tma_descriptor_args" in func.attrs:
                 self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
             host_code = str(func)
@@ -369,13 +375,18 @@ class TLCUDASourceWrapper(object):

 class TLHIPSourceWrapper(TLCUDASourceWrapper):
+    """
+    A wrapper class for the TileLang HIP backend.
+    """

     def __init__(self,
                  scheduled_ir_module: IRModule,
                  source: str,
                  target: Target,
+                 device_mod: Optional[IRModule] = None,
+                 host_mod: Optional[IRModule] = None,
                  pass_configs: Optional[Dict[str, Any]] = None):
-        super().__init__(scheduled_ir_module, source, target, pass_configs)
+        super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)

     def get_hip_init_func(self):
         # Initialize an empty string for the CUDA function call
@@ -418,16 +429,22 @@ class TLCPUSourceWrapper(object):
         """)

     backend = "tl"
-    backend = "tl"
+    device_mod: Optional[IRModule] = None
+    host_mod: Optional[IRModule] = None
+    pass_configs: Optional[Dict[str, Any]] = None

     def __init__(self,
                  scheduled_ir_module: IRModule,
                  source: str,
                  target: Target,
+                 device_mod: Optional[IRModule] = None,
+                 host_mod: Optional[IRModule] = None,
                  pass_configs: Optional[Dict[str, Any]] = None):
         self.mod = scheduled_ir_module
         self.target = target
         self.source = source
+        self.device_mod = device_mod
+        self.host_mod = host_mod
         self.pass_configs = pass_configs
         self.function_names: Optional[str] = None
         self.dynamic_smem_buf: Optional[int] = None
@@ -563,6 +580,14 @@ class TLCPUSourceWrapper(object):

 class TLWrapper(BaseWrapper):
+    """
+    A wrapper class for the TileLang backend.
+    """
+    device_mod: Optional[IRModule] = None
+    host_mod: Optional[IRModule] = None
+    pass_configs: Optional[Dict[str, Any]] = None
+    target: Optional[Target] = None
+    lib: Optional[object] = None

     def __init__(self, target: Target):
         super().__init__()
@@ -577,6 +602,12 @@ class TLWrapper(BaseWrapper):
     def assign_pass_configs(self, pass_configs: Dict[str, Any]):
         self.pass_configs = pass_configs

+    def assign_host_module(self, host_mod: IRModule):
+        self.host_mod = host_mod
+
+    def assign_device_module(self, device_mod: IRModule):
+        self.device_mod = device_mod
+
     # Get Scheduled Rt Module and return source to be compiled
     def wrap(self, c_source: str):
         assert self.scheduled_ir_module is not None, "Please assign optimized module first."
@@ -588,5 +619,11 @@ class TLWrapper(BaseWrapper):
             wrapper_class = TLCPUSourceWrapper
         else:
             raise ValueError(f"Unsupported platform: {self.arch.platform}")
-        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target, self.pass_configs)
+        wrapper = wrapper_class(
+            scheduled_ir_module=self.scheduled_ir_module,
+            source=c_source,
+            target=self.target,
+            device_mod=self.device_mod,
+            host_mod=self.host_mod,
+            pass_configs=self.pass_configs)
         return wrapper.lib_code
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
......
...@@ -4,10 +4,15 @@ import tilelang ...@@ -4,10 +4,15 @@ import tilelang
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.jit.adapter import TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter from tilelang.jit.adapter import (
TorchDLPackKernelAdapter,
BaseKernelAdapter,
CtypesKernelAdapter,
CythonKernelAdapter,
)
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.profiler import Profiler, TensorSupplyType from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.engine.param import KernelParam from tilelang.engine.param import KernelParam, CompiledArtifact
class JITKernel(object): class JITKernel(object):
...@@ -16,15 +21,15 @@ class JITKernel(object): ...@@ -16,15 +21,15 @@ class JITKernel(object):
Attributes Attributes
---------- ----------
rt_mod : tvm.runtime.Module artifact : CompiledArtifact
The runtime module compiled by TVM. The compiled artifact containing the runtime module and parameters.
params : List[KernelParam] adapter : BaseKernelAdapter
Parameters for the compiled runtime module (e.g., weights or constants). The adapter for the compiled function.
torch_function : Callable torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function. The compiled function that can be invoked as a PyTorch-compatible function.
""" """
rt_mod: tvm.runtime.Module = None
params: List[KernelParam] = None artifact: CompiledArtifact = None
adapter: BaseKernelAdapter = None adapter: BaseKernelAdapter = None
torch_function: Callable = None torch_function: Callable = None
...@@ -37,8 +42,7 @@ class JITKernel(object): ...@@ -37,8 +42,7 @@ class JITKernel(object):
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
verbose: bool = False, verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None, pass_configs: Optional[Dict[str, Any]] = None,
rt_module_src: Optional[str] = None, from_database: bool = False,
rt_params: dict = None,
): ):
""" """
Initializes a TorchFunction instance. Initializes a TorchFunction instance.
...@@ -59,9 +63,11 @@ class JITKernel(object): ...@@ -59,9 +63,11 @@ class JITKernel(object):
Whether to enable verbose output (default: False). Whether to enable verbose output (default: False).
pass_configs : dict, optional pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext. Additional keyword arguments to pass to the Compiler PassContext.
Available options: Available options:
"tir.disable_vectorize": bool, default: False "tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False "tl.disable_tma_lower": bool, default: False
from_database : bool, optional
Whether to create a TorchFunction from a database.
""" """
self.out_idx = out_idx self.out_idx = out_idx
self.execution_backend = execution_backend self.execution_backend = execution_backend
...@@ -73,42 +79,6 @@ class JITKernel(object): ...@@ -73,42 +79,6 @@ class JITKernel(object):
pass_configs = {} pass_configs = {}
self.pass_configs = pass_configs self.pass_configs = pass_configs
if rt_module_src and rt_params:
self.rt_mod = None
self.params = rt_params
adapter = None
# Create an adapter based on the specified execution backend.
if execution_backend == "dlpack":
# assert dlpack not supported
raise ValueError(f"Invalid execution backend: {execution_backend}")
elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter.from_database(
params=self.params,
result_idx=out_idx,
target=target,
func_or_mod=func,
kernel_global_source=rt_module_src,
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter.from_database(
rt_mod_src=rt_module_src,
params=self.params,
result_idx=out_idx,
target=target,
func_or_mod=func,
verbose=verbose,
pass_configs=pass_configs,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
self.adapter = adapter
self.torch_function = adapter.func
return
         # If the target is specified as a string, validate it and convert it to a TVM Target.
         if isinstance(target, str):
             assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
@@ -118,12 +88,20 @@ class JITKernel(object):
             target = Target(target)
         # Validate the execution backend.
-        assert execution_backend in ["dlpack", "ctypes",
-                                     "cython"], f"Invalid execution backend. {execution_backend}"
+        assert execution_backend in [
+            "dlpack",
+            "ctypes",
+            "cython",
+        ], f"Invalid execution backend. {execution_backend}"
         if execution_backend == "cython":
             from tilelang.contrib.cc import get_cplus_compiler
-            assert get_cplus_compiler(
-            ) is not None, "Cython backend requires a C++ compiler, please install or use other backends."
+            assert (
+                get_cplus_compiler() is not None
+            ), "Cython backend requires a C++ compiler, please install or use other backends."
+        if from_database:
+            return
         # Compile the TileLang function and create a kernel adapter for execution.
         adapter = self._compile_and_create_adapter(func)
@@ -132,6 +110,43 @@ class JITKernel(object):
         self.adapter = adapter
         self.torch_function = adapter.func
+
+    @classmethod
+    def from_database(
+        cls,
+        func: PrimFunc,
+        kernel_global_source: str,
+        kernel_lib_path: str,
+        params: List[KernelParam],
+        target: Union[str, Target],
+        target_host: Union[str, Target],
+        out_idx: Union[List[int], int],
+        execution_backend: Literal["dlpack", "ctypes", "cython"],
+        pass_configs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Alternative constructor to create a TorchFunction directly from a database.
+        """
+        instance = cls(
+            func=func,
+            out_idx=out_idx,
+            execution_backend=execution_backend,
+            target=target,
+            target_host=target_host,
+            pass_configs=pass_configs,
+            from_database=True,
+        )
+        instance.adapter = instance._create_adapter_from_database(
+            func_or_mod=func,
+            params=params,
+            result_idx=out_idx,
+            target=target,
+            kernel_global_source=kernel_global_source,
+            kernel_lib_path=kernel_lib_path,
+        )
+        instance.torch_function = instance.adapter.func
+        return instance
+
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         """
         Invokes the compiled function with the given arguments.
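The `from_database` classmethod above follows a common alternative-constructor pattern: `__init__` accepts a `from_database` flag that returns early before compilation, and the classmethod then wires cached state onto the instance. A minimal, self-contained sketch of that pattern (all names here are illustrative, not the real tilelang API):

```python
from typing import Optional


class Kernel:
    """Sketch of the alternative-constructor pattern: __init__ can
    short-circuit compilation, and a classmethod finishes building the
    instance from cached state instead."""

    def __init__(self, source: Optional[str] = None, from_database: bool = False):
        self.compiled = False
        self.source = source
        if from_database:
            # Skip the expensive compile step; from_database() fills in state.
            return
        self._compile()

    def _compile(self) -> None:
        # Stand-in for the real lowering/compilation step.
        self.compiled = True

    @classmethod
    def from_database(cls, cached_source: str) -> "Kernel":
        instance = cls(from_database=True)
        instance.source = cached_source
        instance.compiled = True  # a cached binary counts as compiled
        return instance


fresh = Kernel(source="__global__ void k() {}")
cached = Kernel.from_database("__global__ void k() {}")
```

The key design point is that the flag keeps all validation in `__init__` (it still runs for database-loaded kernels) while skipping only the compile step.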
@@ -173,37 +188,39 @@ class JITKernel(object):
         # Compile the function with TVM, optimizing with shared memory lowering.
         with tvm.transform.PassContext(opt_level=3, config=pass_configs):
-            rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)
-        # Store the runtime module and parameters for later use.
-        self.rt_mod = rt_mod
-        self.params = params
+            artifact = tilelang.lower(tilelang_func, target=target, target_host=target_host)
+        self.artifact = artifact
         # Create an adapter based on the specified execution backend.
         if execution_backend == "dlpack":
             # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
-            adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
+            # But we need to ensure that the runtime is enabled and the runtime module is not None.
+            assert tvm.runtime.enabled("llvm"), "DLPack backend requires LLVM runtime."
+            assert (artifact.rt_mod is not None), "DLPack backend requires a runtime module."
+            adapter = TorchDLPackKernelAdapter(
+                artifact.rt_mod, params=artifact.params, result_idx=out_idx)
         elif execution_backend == "ctypes":
-            # TODO(Lei): global source extraction can be simplified
-            kernel_global_source = rt_mod.imported_modules[0].get_source()
             adapter = CtypesKernelAdapter(
-                params=params,
+                params=artifact.params,
                 result_idx=out_idx,
                 target=target,
                 func_or_mod=tilelang_func,
-                kernel_global_source=kernel_global_source,
+                host_mod=artifact.host_mod,
+                device_mod=artifact.device_mod,
+                kernel_global_source=artifact.kernel_source,
                 verbose=verbose,
                 pass_configs=pass_configs,
             )
         elif execution_backend == "cython":
-            # TODO(Lei): global source extraction can be simplified
-            kernel_global_source = rt_mod.imported_modules[0].get_source()
             adapter = CythonKernelAdapter(
-                params=params,
+                params=artifact.params,
                 result_idx=out_idx,
                 target=target,
                 func_or_mod=tilelang_func,
-                kernel_global_source=kernel_global_source,
+                host_mod=artifact.host_mod,
+                device_mod=artifact.device_mod,
+                kernel_global_source=artifact.kernel_source,
                 verbose=verbose,
                 pass_configs=pass_configs,
             )
@@ -213,6 +230,45 @@ class JITKernel(object):
         return adapter
+
+    def _create_adapter_from_database(
+        self,
+        params: List[KernelParam],
+        result_idx: Union[List[int], int],
+        target: Union[str, Target],
+        func_or_mod: Union[PrimFunc, tvm.runtime.Module],
+        kernel_global_source: str,
+        kernel_lib_path: str,
+    ) -> BaseKernelAdapter:
+        target = self.target
+        execution_backend = self.execution_backend
+        # Create an adapter based on the specified execution backend.
+        if execution_backend == "dlpack":
+            raise ValueError("DLPack backend is not supported for TileLang JIT.")
+        elif execution_backend == "ctypes":
+            adapter = CtypesKernelAdapter.from_database(
+                params=params,
+                result_idx=result_idx,
+                target=target,
+                func_or_mod=func_or_mod,
+                kernel_global_source=kernel_global_source,
+                kernel_lib_path=kernel_lib_path,
+            )
+        elif execution_backend == "cython":
+            adapter = CythonKernelAdapter.from_database(
+                params=params,
+                result_idx=result_idx,
+                target=target,
+                func_or_mod=func_or_mod,
+                kernel_global_source=kernel_global_source,
+                kernel_lib_path=kernel_lib_path,
+            )
+        else:
+            # Handle invalid backend.
+            raise ValueError(f"Invalid execution backend: {execution_backend}")
+        return adapter
+
     @classmethod
     def from_tilelang_function(cls, tilelang_func: PrimFunc, **kwargs):
         """
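The backend dispatch in `_create_adapter_from_database` (and in `_compile_and_create_adapter`) can also be expressed as a registry lookup. A self-contained sketch with illustrative adapter classes, not the real tilelang ones; note that `"dlpack"` is deliberately absent from the registry, mirroring how the real method rejects it:

```python
from typing import Callable, Dict


class CtypesAdapter:
    """Illustrative stand-in for CtypesKernelAdapter.from_database."""

    def __init__(self, lib_path: str):
        self.lib_path = lib_path


class CythonAdapter:
    """Illustrative stand-in for CythonKernelAdapter.from_database."""

    def __init__(self, lib_path: str):
        self.lib_path = lib_path


# Registry mirroring the if/elif chain; unknown or unsupported
# backends (including "dlpack") fall through to a ValueError.
ADAPTERS: Dict[str, Callable[[str], object]] = {
    "ctypes": CtypesAdapter,
    "cython": CythonAdapter,
}


def create_adapter(backend: str, lib_path: str):
    try:
        factory = ADAPTERS[backend]
    except KeyError:
        raise ValueError(f"Invalid execution backend: {backend}")
    return factory(lib_path)
```

A registry keeps the set of supported backends in one place, so adding a backend later is a one-line change rather than another `elif` branch.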
@@ -261,17 +317,29 @@ class JITKernel(object):
         """
         if self.execution_backend in {"ctypes", "cython"}:
             return self.adapter.get_kernel_source()
-        return self.rt_mod.imported_modules[0].get_source()
+        return self.artifact.kernel_source

     def get_host_source(self) -> str:
         """
         Returns the source code of the host function.
         """
-        return self.rt_mod.get_source()
+        return str(self.artifact.host_mod)

     def run_once(self, func: Optional[Callable] = None) -> None:
         return self.get_profiler().run_once(func)

+    @property
+    def params(self) -> List[KernelParam]:
+        return self.artifact.params if self.artifact else self.adapter.params
+
+    @property
+    def kernel_source(self) -> str:
+        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
+
+    @property
+    def host_source(self) -> str:
+        return str(self.artifact.host_mod) if self.artifact else ""
+
     def export_library(self, kernel_file: str) -> None:
         """
         Exports the compiled kernel function to a shared library file.
...
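The new `params` and `kernel_source` properties use a fallback: a freshly compiled kernel carries an `artifact`, while a database-loaded kernel has only an `adapter`. A minimal sketch of that artifact-or-adapter fallback (class and attribute names here are simplified stand-ins):

```python
from typing import List, Optional


class Artifact:
    """Stand-in for the compilation artifact produced by lowering."""

    def __init__(self, params: List[str], kernel_source: str):
        self.params = params
        self.kernel_source = kernel_source


class Adapter:
    """Stand-in for a database-restored kernel adapter."""

    def __init__(self, params: List[str], kernel_global_source: str):
        self.params = params
        self.kernel_global_source = kernel_global_source


class Kernel:
    def __init__(self, artifact: Optional[Artifact], adapter: Adapter):
        # artifact is None when the kernel was loaded from the database.
        self.artifact = artifact
        self.adapter = adapter

    @property
    def params(self) -> List[str]:
        return self.artifact.params if self.artifact else self.adapter.params

    @property
    def kernel_source(self) -> str:
        return (self.artifact.kernel_source
                if self.artifact else self.adapter.kernel_global_source)


compiled = Kernel(Artifact(["A", "B"], "src_compiled"), Adapter([], ""))
loaded = Kernel(None, Adapter(["A", "B"], "src_cached"))
```

Both paths expose the same read-only interface, so callers never need to know whether the kernel was compiled in-process or restored from the cache.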