Commit 8c8d8ca2, authored by Lei Wang, committed by LeiWang1999
Browse files

[Refactor] Refactor `jit` to `_JitImplementation` to support `@tilelang.jit` (#502)

* [Refactor] Rename `jit` class to `_JitImplementation` and improve debug path handling

* Refactored the `jit` class to `_JitImplementation` for clarity and encapsulation.
* Enhanced handling of `debug_root_path` to ensure it is correctly set as an absolute path when provided.
* Updated the public `jit` function to serve as a decorator interface, allowing for both default and configured usage.
* Added validation to ensure input tensors are contiguous in the Cython wrapper, improving error handling.

* [Refactor] Improve formatting and handling in `_JitImplementation` and `jit` function

* Refactored the `_JitImplementation` class to enhance readability by adjusting comment formatting and consolidating conditions for setting `debug_root_path`.
* Updated the `jit` function signature for better alignment and clarity in parameter definitions.
* Ensured consistent spacing and comments throughout the code for improved maintainability.

* [Refactor] Update GEMM test parameters for performance optimization

* Set num_stages to 0 and adjusted matrix dimensions in the GEMM test function to enhance performance and consistency across tests in test_tilelang_jit_gemm.py.
* Reduced the number of threads used in the test to align with the updated configuration, improving overall test efficiency.

* [Refactor] Enhance buffer error logging in layout inference

* Updated the warning message in layout inference to provide clearer context when a buffer cannot be inferred due to its absence in the use list. This change improves the clarity of error reporting during layout inference operations.
* Refactored tensor handling in the Cython wrapper to ensure input tensors are checked for contiguity before processing, enhancing error handling and robustness in tensor management.

* bugfix
parent 67bd9f69
......@@ -298,9 +298,10 @@ public:
// Check if buffer exists in use_list_
if (!use_list_.count(buffer)) {
LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
<< "Potential mismatch between inference updates and "
<< "use_list_.";
LOG(WARNING) << "Layout inference failed for buffer " << buffer
<< ". "
<< "The buffer cannot be inferred with current layout "
"inference rules.";
continue;
}
......
......@@ -67,7 +67,7 @@ def run_gemm_kernel_jit(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=0,
num_threads=128,
):
matmul_kernel = matmul_kernel_jit(
......@@ -117,9 +117,9 @@ def test_gemm_f16f16f16_nn_kernel_jit():
"float16",
"float16",
128,
256,
128,
32,
2,
0,
)
......
......@@ -112,7 +112,7 @@ _P = ParamSpec("_P")
_RProg = TypeVar("_RProg", bound=Program)
class jit:
class _JitImplementation:
# Overload __init__ to help type checkers understand the effect of return_program
# The '-> None' is for __init__ itself. The crucial part is Literal for return_program.
@overload
......@@ -188,28 +188,23 @@ class jit:
original program's result and the compiled kernel. If False, only the
compiled kernel is returned (default: False).
"""
if debug_root_path is None:
# This logic was previously under 'if debug and debug_root_path is None:'
# Now, if debug_root_path is explicitly None, we don't try to set a default path.
# If a user wants debugging, they must provide a path.
pass
elif not path.isabs(debug_root_path): # If a relative path is given, make it absolute
try:
# This assumes the file is part of a typical project structure
base_path = path.dirname(path.dirname(path.dirname(__file__)))
debug_root_path = path.join(base_path, debug_root_path)
except NameError: # __file__ is not defined (e.g., in a REPL or notebook)
# Fallback to making it absolute based on current working directory if __file__ fails
debug_root_path = path.abspath(debug_root_path)
self.out_idx = out_idx
self.execution_backend = execution_backend
self.target = target
self.target_host = target_host
self.verbose = verbose
self.pass_configs = pass_configs
self.debug_root_path: Optional[str] = debug_root_path
self.return_program: bool = return_program
self.return_program = return_program # Stored from args
# Corrected debug_root_path handling
self.debug_root_path = debug_root_path
if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
try:
base_path = path.dirname(path.dirname(path.dirname(__file__)))
self.debug_root_path = path.join(base_path, self.debug_root_path)
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
# If debug_root_path was None initially, it remains None.
# Type hint the caches
self._program_cache: Dict[tuple, _RProg] = {}
......@@ -217,47 +212,35 @@ class jit:
# Overload __call__ based on the value of self.return_program
# This tells the type checker what the *wrapper* function will return.
# The wrapper will take the same parameters P as the original function.
# Case 1: return_program is True
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
# This signature is chosen by the type checker if self.return_program is True
# (inferred from the __init__ call).
...
# Case 2: return_program is False (or not specified, defaulting to False)
@overload
def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Kernel]:
# This signature is chosen if self.return_program is False.
...
# Actual implementation of __call__
def __call__(
self, func: Union[Callable[_P, _RProg], PrimFunc]
) -> Callable[_P, Any]: # Any for implementation flexibility
self,
func: Callable[_P, _RProg] # func is Union[Callable[_P, _RProg], PrimFunc] in original
) -> Callable[_P, Any]:
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: # Use _P.args and _P.kwargs
# Create a hashable key. args is already a tuple.
# For kwargs, convert to a sorted tuple of items to ensure consistent ordering.
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
# Check if both program and kernel are cached.
# If program is not cached, we'll recompute both.
# (The original check 'key not in self._program_cache or key not in self._kernel_cache'
# implies that if either is missing, both are recomputed and stored.
# A simpler 'key not in self._program_cache' would often suffice if they are always
# added together.)
if key not in self._program_cache: # Assuming if program isn't there, kernel isn't either or needs refresh
if isinstance(func, PrimFunc):
program_result = func
elif isinstance(func, Callable):
program_result = func(*args, **kwargs)
if key not in self._program_cache:
# Ensure 'func' (the original user function) is used correctly
program_result_source = func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(func)}")
raise ValueError(f"Invalid function type: {type(program_result_source)}")
kernel_result = compile(
program_result,
......@@ -269,10 +252,9 @@ class jit:
pass_configs=self.pass_configs,
)
if self.debug_root_path: # Check if a path is provided
func_name = func.__name__
if self.debug_root_path:
func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
# Ensure the debug directory exists
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
......@@ -280,7 +262,6 @@ class jit:
self._program_cache[key] = program_result
self._kernel_cache[key] = kernel_result
# Retrieve from cache (even if just populated)
cached_program = self._program_cache[key]
cached_kernel = self._kernel_cache[key]
......@@ -290,3 +271,82 @@ class jit:
return cached_kernel
return wrapper
def jit(  # Public decorator interface; the machinery lives in _JitImplementation.
        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
        *,  # Every option after `func` is keyword-only.
        out_idx: Any = None,
        target: Union[str, Target] = "auto",
        target_host: Optional[Union[str, Target]] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
        verbose: bool = False,
        pass_configs: Optional[Dict[str, Any]] = None,
        debug_root_path: Optional[str] = None,
        return_program: bool = False):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.

    This decorator can be used in two ways:

    1. Without arguments (e.g., `@tilelang.jit`):
       the decorated function is passed in as `func` and is wrapped immediately
       using the default settings.
    2. With arguments (e.g., `@tilelang.jit(target="cuda", return_program=True)`):
       `func` is left as None and a configured decorator is returned, which is
       then applied to the function.

    Parameters
    ----------
    func : Union[Callable, PrimFunc, None], optional
        The function to decorate when used as a bare `@tilelang.jit`. Leave it as
        None when configuring the decorator through keyword arguments. Any
        non-callable value that is not a `PrimFunc` is ignored and a configured
        decorator is returned.
    out_idx : Any, optional
        Forwarded unchanged to `_JitImplementation` as `out_idx`. Keyword-only,
        so it must be written as `@tilelang.jit(out_idx=...)`. Defaults to None.
    target : Union[str, Target], optional
        Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
    target_host : Union[str, Target], optional
        Target host for cross-compilation. Defaults to None.
    execution_backend : Literal["dlpack", "ctypes", "cython"], optional
        Backend for kernel execution and argument passing. Defaults to "cython".
    verbose : bool, optional
        Enables verbose logging during compilation. Defaults to False.
    pass_configs : Optional[Dict[str, Any]], optional
        Configurations for TVM's pass context. Defaults to None.
    debug_root_path : Optional[str], optional
        Directory to save compiled kernel source for debugging. Defaults to None.
    return_program : bool, optional
        If True, the decorated function returns a tuple (original program's result,
        compiled kernel). Otherwise, only the compiled kernel is returned.
        Defaults to False.

    Returns
    -------
    Callable
        Either a JIT-compiled wrapper around the input function (when `func` is
        callable), or a configured `_JitImplementation` decorator instance that
        can then be applied to a function.

    Raises
    ------
    ValueError
        If `func` is a non-callable `PrimFunc`; decorating a prim_func directly
        is not supported yet.
    """
    # `callable` is tested first on purpose: this keeps the original branch
    # priority, so a callable object is always wrapped, whatever its type.
    decorating_directly = callable(func)

    # Reject before building anything, exactly as the original branch order did.
    if not decorating_directly and isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.")

    # Both usages need the same configured implementation object; they only
    # differ in whether it is applied to `func` right away.
    decorator = _JitImplementation(
        out_idx=out_idx,
        target=target,
        target_host=target_host,
        execution_backend=execution_backend,
        verbose=verbose,
        pass_configs=pass_configs,
        debug_root_path=debug_root_path,
        return_program=return_program)

    if decorating_directly:
        # Bare `@tilelang.jit`: hand back the compiled-kernel wrapper.
        return decorator(func)

    # `@tilelang.jit(...)`: hand back the decorator to be applied later.
    return decorator
......@@ -114,20 +114,23 @@ cdef class CythonKernelWrapper:
# Convert tensor pointers to C void pointers for kernel call
call_args = []
for i in range(len(tensor_list)):
if isinstance(tensor_list[i], torch.Tensor):
call_args.append(ctypes.c_void_p(tensor_list[i].data_ptr()))
elif isinstance(tensor_list[i], int):
tensor = tensor_list[i]
if isinstance(tensor, torch.Tensor):
if not tensor.is_contiguous():
raise ValueError(f"Input tensor at index {i} must be contiguous")
call_args.append(ctypes.c_void_p(tensor.data_ptr()))
elif isinstance(tensor, int):
# Dynamic symbolics which are passed as integer arguments
if i in self.ptr_map:
call_args.append(ctypes.c_void_p(tensor_list[i]))
call_args.append(ctypes.c_void_p(tensor))
else:
call_args.append(tensor_list[i])
elif isinstance(tensor_list[i], float):
call_args.append(ctypes.c_float(tensor_list[i]))
elif isinstance(tensor_list[i], bool):
call_args.append(ctypes.c_bool(tensor_list[i]))
call_args.append(tensor)
elif isinstance(tensor, float):
call_args.append(ctypes.c_float(tensor))
elif isinstance(tensor, bool):
call_args.append(ctypes.c_bool(tensor))
else:
raise ValueError(f"Unsupported tensor type: {type(tensor_list[i])}")
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
# Check buffer device
# cdef str tensor_list_device_type = tensor_list[0].device.type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment