"vscode:/vscode.git/clone" did not exist on "0d68ddf3275b20b0d12cfd3d0a9f002fecfe001c"
Commit 622fa042 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Simplify kernel source extraction in JIT adapters (#230)

* [Enhancement] Simplify kernel source extraction in JIT adapters

- Updated CtypesKernelAdapter and CythonKernelAdapter to include kernel_global_source for improved source code retrieval.
- Modified the _legalize_result_idx method in BaseKernelAdapter to accept Optional[List[int]] for better flexibility.
- Added comments to clarify the purpose of kernel_global_source in both adapters, enhancing code readability and maintainability.

* [Refactor] Update parameter type in CythonKernelAdapter constructor

- Changed the parameter type from List[TensorType] to List[KernelParam] in the CythonKernelAdapter's __init__ method to enhance type consistency and align with recent refactoring efforts across modules.
parent 7ae35298
......@@ -17,7 +17,7 @@ class BaseKernelAdapter(ABC):
self.result_idx = self._legalize_result_idx(result_idx)
self._post_init()
def _legalize_result_idx(self, result_idx: List[int]) -> List[int]:
def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
......
......@@ -25,7 +25,10 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# Class attributes to store compiled kernel information
target = "cuda"
ir_module = None
ir_module: Optional[tvm.IRModule] = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
......@@ -38,11 +41,11 @@ class CtypesKernelAdapter(BaseKernelAdapter):
param_shapes: Optional[List[List]] = None # Cache for parameter shapes
def __init__(self,
rt_mod,
params: List[TensorType],
result_idx: List[int],
target,
target: str,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
kernel_global_source: Optional[str] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
......@@ -55,9 +58,9 @@ class CtypesKernelAdapter(BaseKernelAdapter):
func_or_mod: TIR function or module to be compiled
verbose: Enable verbose logging
"""
self.mod = rt_mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......@@ -215,7 +218,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
if kernel_only:
return self.mod.imported_modules[0].get_source()
return self.kernel_global_source
else:
assert self.wrapped_source is not None, "Wrapped source is not available"
return self.wrapped_source
......@@ -131,6 +131,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
# Class attributes to store compiled kernel information
target: Union[str, Target] = "cuda"
ir_module: Optional[tvm.IRModule] = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
kernel_global_source: Optional[str] = None
lib: Optional[ctypes.CDLL] = None # Compiled library handle
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
......@@ -148,11 +151,11 @@ class CythonKernelAdapter(BaseKernelAdapter):
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self,
rt_mod,
params: List[KernelParam],
result_idx: List[int],
target: Union[str, Target],
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
kernel_global_source: str,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
......@@ -165,9 +168,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
func_or_mod: TIR function or module to be compiled
verbose: Enable verbose logging
"""
self.mod = rt_mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self.kernel_global_source = kernel_global_source
if isinstance(func_or_mod, tir.PrimFunc):
self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
......@@ -332,7 +335,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
def get_kernel_source(self, kernel_only: bool = False):
"""Returns the source code of the compiled kernel."""
if kernel_only:
return self.mod.imported_modules[0].get_source()
return self.kernel_global_source
else:
assert self.wrapped_source is not None, "Wrapped source is not available"
return self.wrapped_source
......@@ -147,22 +147,26 @@ class JITKernel(object):
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
elif execution_backend == "ctypes":
# TODO(Lei): global source extraction can be simplified
kernel_global_source = rt_mod.imported_modules[0].get_source()
adapter = CtypesKernelAdapter(
rt_mod,
params=params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
kernel_global_source=kernel_global_source,
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "cython":
# TODO(Lei): global source extraction can be simplified
kernel_global_source = rt_mod.imported_modules[0].get_source()
adapter = CythonKernelAdapter(
rt_mod,
params=params,
result_idx=out_idx,
target=target,
func_or_mod=tilelang_func,
kernel_global_source=kernel_global_source,
verbose=verbose,
pass_configs=pass_configs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment