Commit e789808b authored by alex_xiao, committed by LeiWang1999

[Feature] Add database storage for JITKernel cache with Cython and Ctypes adapters (#213)



* [Dev] Add database mechanism to cache

* [Dev] Fix database cache and test for it

* [Dev] Refactor env.py to use TILELANG_CACHE_DIR and remove extra comment.

* [Refactor] Improve code formatting and readability in multiple files

* [Enhancement] Add execution backend options and improve kernel adapter initialization

* [Refactor] Rename cached function to cached_kernel and update related references

* [Enhancement] Enable target and target_host parameters in kernel loading and improve gemm test case

* [Enhancement] Update kernel compilation to specify execution backend as "cython"

* [Refactor] Rename cached_kernel to cached and update references in the codebase

* [Enhancement] Un-comment and add test cases for matrix multiplication correctness; improve kernel caching logic and remove redundant code

* [Refactor] Clean up code formatting and improve readability in cache and adapter modules

* [Refactor] Remove unused imports

* [Refactor] Update cached function signature to use PrimFunc and Optional types for improved type safety

* [Refactor] Update cached function calls to use PrimFunc and improve parameter handling

* [Refactor] Clean up import statements and improve code formatting in cache and kernel test files

* Update tilelang/jit/kernel.py

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent efceb6ed
@@ -4,120 +4,100 @@
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import cached
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """
    Defines a matrix multiplication primitive function using tilelang.

    This function constructs a tilelang primitive function for matrix multiplication,
    optimized for execution on hardware accelerators. It utilizes shared memory and
    fragment memory for performance.

    Args:
        M (int): Number of rows in matrix A and C.
        N (int): Number of columns in matrix B and C.
        K (int): Number of columns in matrix A and rows in matrix B.
        block_M (int): Block size for M dimension in shared memory and fragment.
        block_N (int): Block size for N dimension in shared memory and fragment.
        block_K (int): Block size for K dimension in shared memory.
        dtype (str, optional): Data type for input matrices A and B, and output C. Defaults to "float16".
        accum_dtype (str, optional): Accumulation data type for internal computations. Defaults to "float".

    Returns:
        T.PrimFunc: A tilelang primitive function representing the matrix multiplication.
    """

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_cache_matmul():
    """
    Demonstrates the usage of the cached matrix multiplication kernel.

    This function defines a reference PyTorch matrix multiplication,
    creates a cached kernel from the tilelang matmul function,
    runs the kernel with random input tensors, compares the output with the reference,
    and prints the CUDA kernel source code.
    """

    def ref_program(A, B):
        """
        Reference PyTorch matrix multiplication for comparison.
        """
        import torch
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.half)  # Assuming dtype="float16" in matmul
        return C

    func = matmul(1024, 1024, 1024, 128, 128, 32)
    kernel = cached(func, [2], execution_backend="cython")

    import torch
    a = torch.randn(1024, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()
    c = kernel(a, b)

    print("\nOutput from Cached Kernel:")
    print(c)

    ref_c = ref_program(a, b)
    print("\nReference PyTorch Output:")
    print(ref_c)

    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("\nOutputs are close (within tolerance).")

    # Get CUDA Source
    print("\nCUDA Kernel Source:")
    print(kernel.get_kernel_source())


def test_cache_matmul_f16f16f16_nn():
    """
    Test function for cached matrix multiplication (float16 inputs, float16 output, no transpose).
    """
    run_cache_matmul()


if __name__ == "__main__":
...
@@ -235,7 +235,7 @@ class _attention(torch.autograd.Function):
        BATCH, N_CTX, H, D_HEAD = q.shape
        block_M = 64
        block_N = 64 if D_HEAD <= 128 else 32
        mod = cached(flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N), [3, 4])
        o, lse = mod(q, k, v)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
@@ -254,11 +254,11 @@ class _attention(torch.autograd.Function):
        do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
        block_M = 128
        block_N = 128 if D_HEAD <= 64 else 32
        mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2])
        mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1])
        delta = mod_prep(o, do)
        mod = cached(
            flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N), [6, 7, 8])
        dq, dk, dv = mod(q, k, v, do, lse, delta)
        dq = mod_post(dq)
        return dq, dk, dv, None
...
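The two hunks above capture the call-convention change for cached in this commit: callers now build the PrimFunc themselves (e.g. flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)) and pass it to cached, whereas previously the factory function and its arguments were passed separately. A minimal sketch of the new convention, reusing the matmul helper from the test above (assumes a CUDA-capable machine with tilelang installed; shapes are illustrative only):

import torch
from tilelang import cached

# New convention: instantiate the PrimFunc first, then let the cache compile or reuse it.
func = matmul(1024, 1024, 1024, 128, 128, 32)           # returns a T.prim_func
kernel = cached(func, [2], execution_backend="cython")  # out_idx=[2]: C is allocated and returned

a = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
c = kernel(a, b)

# Old convention (removed by this commit) passed the factory and its arguments instead:
#   kernel = cached(matmul, [2], 1024, 1024, 1024, 128, 128, 32)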
"""The cache utils""" """The cache utils with class and database persistence - Init file"""
from tilelang import compile from typing import List, Union, Literal, Optional
from tilelang.jit import JITKernel
from typing import Callable, List, Union
from tvm.target import Target from tvm.target import Target
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from .kernel_cache import KernelCache
# Dictionary to store cached kernels # Create singleton instance of KernelCache
_cached = {} _kernel_cache_instance = KernelCache()
def cached( def cached(
func: Callable, func: PrimFunc = None,
out_idx: List[int] = None, out_idx: List[int] = None,
*args, *args,
target: Union[str, Target] = "auto", target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None, target_host: Union[str, Target] = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
) -> JITKernel: ) -> JITKernel:
""" """
Cache and reuse compiled kernels to avoid redundant compilation. Caches and reuses compiled kerne(ls (using KernelCache class).
Args:
func: Function to be compiled or a PrimFunc that's already prepared
out_idx: Indices specifying which outputs to return
target: Compilation target platform
target_host: Host target for compilation
*args: Arguments passed to func when calling it
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
""" """
global _cached return _kernel_cache_instance.cached(
# Create a unique key based on the function, output indices and arguments func,
key = (func, tuple(out_idx), *args) out_idx,
*args,
# Return cached kernel if available target=target,
if key not in _cached: target_host=target_host,
# Handle both PrimFunc objects and callable functions execution_backend=execution_backend,
program = func if isinstance(func, PrimFunc) else func(*args) verbose=verbose,
pass_configs=pass_configs,
# Compile the program to a kernel )
kernel = compile(program, out_idx=out_idx, target=target, target_host=target_host)
# Store in cache for future use
_cached[key] = kernel
return _cached[key]
def clear_cache(): def clear_cache():
""" """
Clear the entire kernel cache. Clears the entire kernel cache (using KernelCache class).
This function resets the internal cache dictionary that stores compiled kernels.
Use this when you want to free memory or ensure fresh compilation
of kernels in a new context.
""" """
global _cached _kernel_cache_instance.clear_cache()
_cached = {}
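Because the module keeps a single process-wide KernelCache instance, clearing is global and also wipes the persisted entries on disk; a tiny usage sketch:

from tilelang.cache import clear_cache

# Drops the in-memory dictionary and deletes the cache directory contents under TILELANG_CACHE_DIR.
clear_cache()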
"""The cache utils with class and database persistence - KernelCache Class"""
import os
import json
import shutil
from hashlib import sha256
from typing import Callable, List, Literal, Union
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
import threading
import cloudpickle
import logging
from tilelang.env import TILELANG_CACHE_DIR # noqa: F401
class KernelCache:
"""
Caches compiled kernels using a class and database persistence to avoid redundant compilation.
"""
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
def __new__(cls, cache_dir=TILELANG_CACHE_DIR):
"""Singleton pattern to ensure only one KernelCache instance"""
with cls._lock:
if cls._instance is None:
cls._instance = super(KernelCache, cls).__new__(cls)
cls._instance._cache = {} # In-memory cache
cls._instance.cache_dir = cache_dir # Cache directory
os.makedirs(cls._instance.cache_dir, exist_ok=True) # Ensure cache directory exists
cls._instance.logger = logging.getLogger(__name__) # Initialize logger
cls._instance.logger.setLevel(
logging.ERROR) # Set default logging level to ERROR, can be adjusted
return cls._instance
def _generate_key(self, func: Callable, out_idx: List[int],
execution_backend: Literal["dlpack", "ctypes", "cython"], args,
target: Union[str, Target], target_host: Union[str, Target]) -> str:
"""
Generates a unique cache key.
"""
func_binary = cloudpickle.dumps(func)
key_data = {
"func": sha256(func_binary).hexdigest(), # Use SHA256 to generate hash key
"out_idx": tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx],
"args_repr": tuple(
repr(arg) for arg in args
), # Use repr to serialize arguments, may need more robust serialization
"target": str(target),
"target_host": str(target_host) if target_host else None,
"execution_backend": execution_backend,
}
key_string = json.dumps(key_data, sort_keys=True) # Sort keys to ensure consistency
return sha256(key_string.encode()).hexdigest() # Use SHA256 to generate hash key
def cached(
self,
func: PrimFunc = None,
out_idx: List[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
Args:
func: Function to be compiled or a prepared PrimFunc
out_idx: Indices specifying which outputs to return
target: Compilation target platform
target_host: Host target platform
*args: Arguments passed to func
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
key = self._generate_key(func, out_idx, execution_backend, args, target, target_host)
with self._lock: # Thread-safe access to cache
if key in self._cache:
return self._cache[key]
# Attempt to load from disk
kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
execution_backend, pass_configs, func)
if kernel:
self._cache[key] = kernel # Load to in-memory cache
return kernel
# Compile kernel if cache miss
kernel = JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
self._cache[key] = kernel # Store in in-memory cache
self._save_kernel_to_disk(key, kernel, func)
return kernel
def clear_cache(self):
"""
Clears the entire kernel cache, including both in-memory and disk cache.
"""
with self._lock: # Thread-safe operation
self._cache.clear() # Clear in-memory cache
self._clear_disk_cache() # Clear disk cache
def _get_cache_path(self, key: str) -> str:
"""
Gets the cache file path for a given key.
"""
return os.path.join(self.cache_dir, key)
def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = None):
"""
Saves the compiled kernel to disk.
"""
cache_path = self._get_cache_path(key)
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
# Save rt_mod as a str
try:
artifact_path = os.path.join(cache_path, "tvm_tmp_mod.txt")
with open(artifact_path, "w") as f:
f.write(kernel.rt_mod.imported_modules[0].get_source())
except Exception as e:
self.logger.error(f"Error saving kernel module to disk: {e}")
try:
dump_path = os.path.join(cache_path, "tvm_params.pkl")
with open(dump_path, "wb") as f:
cloudpickle.dump(kernel.params, f)
except Exception as e:
self.logger.error(f"Error saving kernel parameters to disk: {e}")
def _load_kernel_from_disk(self,
key: str,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None) -> JITKernel:
"""
Loads kernel from disk.
"""
cache_path = self._get_cache_path(key)
if not os.path.exists(cache_path):
return None
rt_module = None
rt_params = None
try:
artifact_path = os.path.join(cache_path, "tvm_tmp_mod.txt")
with open(artifact_path, "r") as f:
rt_module = f.read()
except Exception as e:
self.logger.error(f"Error loading kernel module from disk: {e}")
try:
dump_path = os.path.join(cache_path, "tvm_params.pkl")
with open(dump_path, "rb") as f:
rt_params = cloudpickle.load(f)
except Exception as e:
self.logger.error(f"Error loading kernel parameters from disk: {e}")
if rt_module and rt_params:
return JITKernel(
rt_module_src=rt_module,
rt_params=rt_params,
execution_backend=execution_backend,
target=target,
target_host=target_host,
out_idx=out_idx,
pass_configs=pass_configs,
func=func,
)
else:
return None
def _clear_disk_cache(self):
"""
Clears the cache directory on disk.
"""
try:
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir) # Delete entire cache directory
os.makedirs(self.cache_dir, exist_ok=True) # Re-create cache directory
except Exception as e:
self.logger.error(f"Error clearing disk cache: {e}")
@@ -40,6 +40,9 @@ TVM_LIBRARY_PATH: str = os.environ.get("TVM_LIBRARY_PATH", None)
TILELANG_TEMPLATE_PATH: str = os.environ.get("TL_TEMPLATE_PATH", None)
TILELANG_PACKAGE_PATH: str = pathlib.Path(__file__).resolve().parents[0]
TILELANG_CACHE_DIR: str = os.environ.get("TILELANG_CACHE_DIR",
                                         os.path.expanduser("~/.tilelang/cache"))

# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
@@ -115,4 +118,5 @@ __all__ = [
    "TVM_LIBRARY_PATH",
    "TILELANG_TEMPLATE_PATH",
    "CUDA_HOME",
    "TILELANG_CACHE_DIR",
]
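Since env.py reads TILELANG_CACHE_DIR at import time with ~/.tilelang/cache as the fallback, the disk cache can be redirected by setting the variable before tilelang is imported; a minimal sketch (the target path is arbitrary):

import os

# Must be set before the first tilelang import, because env.py evaluates it at import time.
os.environ["TILELANG_CACHE_DIR"] = "/tmp/tilelang_cache"

import tilelang  # noqa: E402,F401
from tilelang.env import TILELANG_CACHE_DIR  # noqa: E402

print(TILELANG_CACHE_DIR)  # -> /tmp/tilelang_cache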
@@ -13,6 +13,7 @@ from tvm.target import Target
from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.jit.kernel import JITKernel
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.cache import cached
from logging import getLogger

logger = getLogger(__name__)
@@ -86,8 +87,7 @@ def jit(
    """
    if verbose:
        logger.info(f"Compiling TileLang function:\n{tilelang_func}")
    return compile(
        tilelang_func,
        target=target,
        verbose=verbose,
@@ -119,12 +119,12 @@
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.
    """
    return cached(
        func=func,
        out_idx=out_idx,
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
        verbose=verbose,
        pass_configs=pass_configs,
    )
\ No newline at end of file
@@ -99,6 +99,58 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        self._post_init()

@classmethod
def from_database(cls,
params: List[TensorType],
result_idx: List[int],
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
adapter.ir_module = func_or_mod
# Cache parameter information during initialization
adapter.param_dtypes = [param.dtype for param in params]
adapter.param_shapes = []
for param in params:
native_shape = []
for dim in param.shape:
if isinstance(dim, tir.IntImm):
native_shape.append(int(dim))
elif isinstance(dim, tir.Var):
native_shape.append(dim) # Keep tir.Var for dynamic dimensions
else:
native_shape.append(dim)
adapter.param_shapes.append(native_shape)
adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.wrapper = TLWrapper(adapter.target)
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.wrapper.assign_optimized_module(adapter.ir_module)
adapter.wrapper.assign_pass_configs(pass_configs)
adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
adapter.lib_generator.update_lib_code(adapter.wrapped_source)
adapter.lib_generator.compile_lib()
adapter.lib = adapter.lib_generator.load_lib()
adapter.lib.init()
adapter._post_init()
return adapter
    def _process_dynamic_symbolic(self):
        """Extract information about dynamic shapes from the TIR function.
...
@@ -7,6 +7,7 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tilelang.engine.param import KernelParam
from tvm import tir
from tvm.relay import TensorType
from tilelang.jit.adapter.wrapper import TLWrapper
from tilelang.jit.adapter.libgen import LibraryGenerator
from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target
@@ -209,6 +210,60 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)

        self._post_init()
@classmethod
def from_database(cls,
rt_mod_src: str,
params: List[TensorType],
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
adapter.kernel_global_source = rt_mod_src
if isinstance(func_or_mod, tir.PrimFunc):
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
else:
adapter.ir_module = func_or_mod
target = determine_target(target, return_object=True)
adapter.target = Target.canon_target(determine_target(target))
adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
adapter.buffer_dtype_map = adapter._process_buffer_dtype()
adapter.static_shape_map = adapter._process_static_shape()
adapter.buffer_device_map = adapter._process_buffer_device()
adapter.verbose = verbose
adapter.wrapper = TLWrapper(adapter.target)
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.wrapper.assign_optimized_module(adapter.ir_module)
adapter.wrapper.assign_pass_configs(pass_configs)
adapter.wrapped_source = adapter.wrapper.wrap(adapter.get_kernel_source(kernel_only=True))
adapter.lib_generator.update_lib_code(adapter.wrapped_source)
adapter.lib_generator.compile_lib()
adapter.lib = adapter.lib_generator.load_lib()
try:
adapter.lib.init()
except Exception as e:
raise Exception(
f"Failed to initialize the compiled library for {adapter.target}: {e}") from e
adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params,
adapter.lib)
adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map)
adapter._post_init()
return adapter
    def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
        """Extract information about dynamic shapes from the TIR function.
...
@@ -72,7 +72,7 @@ cdef class CythonKernelWrapper:
            )
        # Use current CUDA stream if none specified
        if stream == -1:
            stream = torch.cuda.current_stream().cuda_stream

        cdef int ins_idx = 0
@@ -86,8 +86,10 @@ cdef class CythonKernelWrapper:
            # Now working with native Python list, no FFI calls needed
            for s in self.param_shapes[i]:
                if isinstance(s, tir.Var):
                    for key in self.dynamic_symbolic_map:
                        if(str(s) == str(key)):
                            ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[key]
                            shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
                else:  # Already converted to Python int during initialization
                    shape.append(s)
        device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
...
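The lookup above now matches dynamic-shape variables by name (string comparison) rather than relying on tir.Var dictionary-key identity, presumably because parameters restored from the cache carry Var objects that no longer hash-match the keys recorded in dynamic_symbolic_map. A standalone sketch of the name-based matching with placeholder data (not the real wrapper state):

# var name -> (index of the tensor that defines it, dimension index within that tensor)
dynamic_symbolic_map = {"m": (0, 0), "n": (1, 1)}
tensor_shapes = [(256, 128), (128, 512)]


def resolve(var_name):
    # Mirrors the loop above: compare by string instead of by object identity.
    for key, (tensor_idx, dim_idx) in dynamic_symbolic_map.items():
        if var_name == str(key):
            return tensor_shapes[tensor_idx][dim_idx]
    raise KeyError(var_name)


print(resolve("n"))  # -> 512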
@@ -37,6 +37,8 @@ class JITKernel(object):
        target_host: Union[str, Target] = None,
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        rt_module_src: Optional[str] = None,
        rt_params: dict = None,
    ):
        """
        Initializes a TorchFunction instance.
@@ -61,7 +63,6 @@
                "tir.disable_vectorize": bool, default: False
                "tl.disable_tma_lower": bool, default: False
        """
        self.out_idx = out_idx
        self.execution_backend = execution_backend
        self.target = target
@@ -72,6 +73,42 @@
            pass_configs = {}
        self.pass_configs = pass_configs

if rt_module_src and rt_params:
self.rt_mod = None
self.params = rt_params
adapter = None
# Create an adapter based on the specified execution backend.
if execution_backend == "dlpack":
# assert dlpack not supported
raise ValueError(f"Invalid execution backend: {execution_backend}")
elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter.from_database(
params=self.params,
result_idx=out_idx,
target=target,
func_or_mod=func,
kernel_global_source=rt_module_src,
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter.from_database(
rt_mod_src=rt_module_src,
params=self.params,
result_idx=out_idx,
target=target,
func_or_mod=func,
verbose=verbose,
pass_configs=pass_configs,
)
else:
# Handle invalid backend.
raise ValueError(f"Invalid execution backend: {execution_backend}")
self.adapter = adapter
self.torch_function = adapter.func
return
        # If the target is specified as a string, validate it and convert it to a TVM Target.
        if isinstance(target, str):
            assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
@@ -234,3 +271,22 @@
    def run_once(self, func: Optional[Callable] = None) -> None:
        return self.get_profiler().run_once(func)
def export_library(self, kernel_file: str) -> None:
"""
Exports the compiled kernel function to a shared library file.
Parameters
----------
kernel_file : str
The path to the shared library file to create.
"""
# rt_module: tvm.runtime.Module = None
# rt_params: dict = None
# adapter: BaseKernelAdapter = None
# torch_function: Callable = None
# rt_module: use export_library to export
# rt_params: use cloudpickle to serialize
# Export the compiled kernel function to a shared library file.
self.rt_module.export_library(kernel_file)
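The rt_module_src/rt_params branch added above is the path KernelCache._load_kernel_from_disk exercises when it rebuilds a kernel from persisted artifacts. A hedged sketch of that flow, following the file names used by the cache code earlier in this diff (key, func, and out_idx are assumed to be supplied by the caller, e.g. the digest from _generate_key and the original PrimFunc):

import os
import cloudpickle
from tilelang.env import TILELANG_CACHE_DIR
from tilelang.jit import JITKernel


def load_cached_kernel(key, func, out_idx):
    """Rebuild a JITKernel from a persisted cache entry (mirrors _load_kernel_from_disk)."""
    entry = os.path.join(TILELANG_CACHE_DIR, key)
    with open(os.path.join(entry, "tvm_tmp_mod.txt")) as f:
        rt_module_src = f.read()
    with open(os.path.join(entry, "tvm_params.pkl"), "rb") as f:
        rt_params = cloudpickle.load(f)
    # Supplying rt_module_src and rt_params routes JITKernel through the from_database adapters
    # instead of lowering the PrimFunc through TVM again.
    return JITKernel(
        func,
        out_idx=out_idx,
        execution_backend="cython",
        rt_module_src=rt_module_src,
        rt_params=rt_params,
    )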
@@ -83,5 +83,4 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
    if return_object:
        return Target(return_var)
    return return_var