"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "8acddd0c0ff611ee1ebc7cc0dce532d8c0f332d2"
Commit 622fa042 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Simplify kernel source extraction in JIT adapters (#230)

* [Enhancement] Simplify kernel source extraction in JIT adapters

- Updated CtypesKernelAdapter and CythonKernelAdapter to include kernel_global_source for improved source code retrieval.
- Modified the _legalize_result_idx method in BaseKernelAdapter to accept Optional[List[int]] for better flexibility.
- Added comments to clarify the purpose of kernel_global_source in both adapters, enhancing code readability and maintainability.

* [Refactor] Update parameter type in CythonKernelAdapter constructor

- Changed the parameter type from List[TensorType] to List[KernelParam] in the CythonKernelAdapter's __init__ method to enhance type consistency and align with recent refactoring efforts across modules.
parent 7ae35298
@@ -17,7 +17,7 @@ class BaseKernelAdapter(ABC):
         self.result_idx = self._legalize_result_idx(result_idx)
         self._post_init()
 
-    def _legalize_result_idx(self, result_idx: List[int]) -> List[int]:
+    def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]:
         params = self.params
         # result_idx is a list of indices of the output tensors
         if result_idx is None:
...
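The Optional[List[int]] change means callers may omit result_idx entirely and let the adapter choose a default. The hunk above truncates the body of _legalize_result_idx, so the sketch below is only a plausible illustration of such a legalization step; the helper name and the exact policy are assumptions, not the repository's implementation:

```python
from typing import List, Optional

def legalize_result_idx(num_params: int, result_idx: Optional[List[int]]) -> List[int]:
    """Illustrative only: normalize an optional list of output-tensor indices."""
    if result_idx is None:
        # Assumed default policy: no explicit outputs requested.
        return []
    legal = []
    for idx in result_idx:
        if idx < 0:
            idx += num_params  # support negative indexing from the end
        if not 0 <= idx < num_params:
            raise ValueError(f"result index {idx} out of range for {num_params} params")
        legal.append(idx)
    return legal
```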
@@ -25,7 +25,10 @@ class CtypesKernelAdapter(BaseKernelAdapter):
 
     # Class attributes to store compiled kernel information
     target = "cuda"
-    ir_module = None
+    ir_module: Optional[tvm.IRModule] = None
+    # The global source code of the kernel -> global means the source code of the kernel
+    # that is not wrapped by the wrapper code
+    kernel_global_source: Optional[str] = None
     lib: Optional[ctypes.CDLL] = None  # Compiled library handle
     wrapped_source: Optional[str] = None  # Generated C++ wrapper code
     # Maps symbolic variables to their corresponding buffer and shape indices
@@ -38,11 +41,11 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     param_shapes: Optional[List[List]] = None  # Cache for parameter shapes
 
     def __init__(self,
-                 rt_mod,
                  params: List[TensorType],
                  result_idx: List[int],
-                 target,
+                 target: str,
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+                 kernel_global_source: Optional[str] = None,
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None):
         """Initialize the adapter with the given TIR function or module.
@@ -55,9 +58,9 @@ class CtypesKernelAdapter(BaseKernelAdapter):
             func_or_mod: TIR function or module to be compiled
             verbose: Enable verbose logging
         """
-        self.mod = rt_mod
         self.params = params
         self.result_idx = self._legalize_result_idx(result_idx)
+        self.kernel_global_source = kernel_global_source
 
         if isinstance(func_or_mod, tir.PrimFunc):
             self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
@@ -215,7 +218,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     def get_kernel_source(self, kernel_only: bool = False):
         """Returns the source code of the compiled kernel."""
         if kernel_only:
-            return self.mod.imported_modules[0].get_source()
+            return self.kernel_global_source
         else:
             assert self.wrapped_source is not None, "Wrapped source is not available"
             return self.wrapped_source
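With this change, kernel_only=True returns the source captured at construction time instead of reaching back into the runtime module. A brief usage sketch, assuming `adapter` is an already constructed CtypesKernelAdapter:

```python
# Raw device source captured at construction time (may be None if it was not provided).
raw_source = adapter.get_kernel_source(kernel_only=True)

# Full generated C++ wrapper; asserts that the wrapped source is available.
wrapped = adapter.get_kernel_source()

print(raw_source)
print(wrapped)
```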
@@ -131,6 +131,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
     # Class attributes to store compiled kernel information
     target: Union[str, Target] = "cuda"
     ir_module: Optional[tvm.IRModule] = None
+    # The global source code of the kernel -> global means the source code of the kernel
+    # that is not wrapped by the wrapper code
+    kernel_global_source: Optional[str] = None
     lib: Optional[ctypes.CDLL] = None  # Compiled library handle
     wrapped_source: Optional[str] = None  # Generated C++ wrapper code
     # Maps symbolic variables to their corresponding buffer and shape indices
@@ -148,11 +151,11 @@ class CythonKernelAdapter(BaseKernelAdapter):
     pass_configs: Optional[Dict[str, Any]] = None
 
     def __init__(self,
-                 rt_mod,
                  params: List[KernelParam],
                  result_idx: List[int],
                  target: Union[str, Target],
                  func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+                 kernel_global_source: str,
                  verbose: bool = False,
                  pass_configs: Optional[Dict[str, Any]] = None):
         """Initialize the adapter with the given TIR function or module.
@@ -165,9 +168,9 @@ class CythonKernelAdapter(BaseKernelAdapter):
             func_or_mod: TIR function or module to be compiled
             verbose: Enable verbose logging
         """
-        self.mod = rt_mod
         self.params = params
         self.result_idx = self._legalize_result_idx(result_idx)
+        self.kernel_global_source = kernel_global_source
 
         if isinstance(func_or_mod, tir.PrimFunc):
             self.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
@@ -332,7 +335,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
     def get_kernel_source(self, kernel_only: bool = False):
         """Returns the source code of the compiled kernel."""
         if kernel_only:
-            return self.mod.imported_modules[0].get_source()
+            return self.kernel_global_source
         else:
             assert self.wrapped_source is not None, "Wrapped source is not available"
             return self.wrapped_source
@@ -147,22 +147,26 @@ class JITKernel(object):
             # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
             adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
         elif execution_backend == "ctypes":
+            # TODO(Lei): global source extraction can be simplified
+            kernel_global_source = rt_mod.imported_modules[0].get_source()
             adapter = CtypesKernelAdapter(
-                rt_mod,
                 params=params,
                 result_idx=out_idx,
                 target=target,
                 func_or_mod=tilelang_func,
+                kernel_global_source=kernel_global_source,
                 verbose=verbose,
                 pass_configs=pass_configs,
             )
         elif execution_backend == "cython":
+            # TODO(Lei): global source extraction can be simplified
+            kernel_global_source = rt_mod.imported_modules[0].get_source()
             adapter = CythonKernelAdapter(
-                rt_mod,
                 params=params,
                 result_idx=out_idx,
                 target=target,
                 func_or_mod=tilelang_func,
+                kernel_global_source=kernel_global_source,
                 verbose=verbose,
                 pass_configs=pass_configs,
             )
...
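The TODO(Lei) comments note that the per-backend extraction could be simplified further. One possible follow-up, purely hypothetical and not part of this commit, is to centralize the lookup in a small helper so each backend branch stops repeating it:

```python
def extract_kernel_global_source(rt_mod) -> str:
    """Hypothetical helper: pull the raw device source out of the runtime module."""
    imported = rt_mod.imported_modules
    # The first imported module is assumed to hold the generated device code.
    return imported[0].get_source() if imported else ""
```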