Commit 5fbfb80b authored by Lei Wang, committed by LeiWang1999

[Autotune] Remove the out_idx argument from the autotune cache (#553)



* [Enhancement] Update AutoTuner and JIT compilation arguments

* Added the ability for the JIT implementation to return its compile arguments, so the autotuner can reuse them when building its cache.
* Modified the `CompileArgs` and `AutotuneResult` classes to accept an optional `out_idx` parameter, improving flexibility in compile-argument handling.
* Refactored `_AutoTunerImplementation` to reuse the compile arguments supplied at JIT decoration time, keeping tuning and caching consistent with the kernel's compilation settings.
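The handshake these bullets describe can be sketched as follows. This is a minimal illustration, not tilelang's actual implementation: the decorator shape, its defaults, and `my_kernel` are assumed for the example; only the `__return_compile_arguments` keyword and the returned keys come from this change.

```python
from functools import wraps


def jit(out_idx=None, execution_backend="cython", target="auto",
        target_host=None, verbose=False, pass_configs=None):
    """Assumed decorator shape; the real tilelang.jit signature differs."""

    def decorator(fn):

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Reserved keyword: instead of compiling, hand the compile
            # arguments back so the autotuner can reuse them for its cache.
            if kwargs.pop('__return_compile_arguments', False):
                return {
                    'out_idx': out_idx,
                    'execution_backend': execution_backend,
                    'target': target,
                    'target_host': target_host,
                    'verbose': verbose,
                    'pass_configs': pass_configs,
                }
            return fn(*args, **kwargs)  # normal compile path

        return wrapper

    return decorator


@jit(out_idx=[1], target="cuda")
def my_kernel(M, N):  # hypothetical kernel factory for the example
    return f"kernel<{M}x{N}>"


# The autotuner asks the decorated function for its compile arguments:
compile_arguments = my_kernel(__return_compile_arguments=True)
assert compile_arguments['out_idx'] == [1]
```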

* Update tilelang/autotuner/param.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* remove redundant comments

* Update tilelang/jit/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 444b7c4e
@@ -550,6 +550,8 @@ class _AutoTunerImplementation:
         def jit_compile(**config_arg):
             return fn(*args, **kwargs, __tune_params=config_arg)
 
+        compile_arguments = fn(__return_compile_arguments=True)
+
         autotuner = AutoTuner(
             fn, configs=configs).set_profile_args(
                 supply_type=self.supply_type,
@@ -561,13 +563,22 @@ class _AutoTunerImplementation:
                 skip_check=self.skip_check,
                 manual_check_prog=self.manual_check_prog,
                 cache_input_tensors=self.cache_input_tensors,
-            )
+            ).set_compile_args(
+                out_idx=compile_arguments['out_idx'],
+                execution_backend=compile_arguments['execution_backend'],
+                target=compile_arguments['target'],
+                target_host=compile_arguments['target_host'],
+                verbose=compile_arguments['verbose'],
+                pass_configs=compile_arguments['pass_configs'],
+            )
         autotuner.jit_compile = jit_compile
         autotuner.set_kernel_parameters(key)
         autotuner.run = partial(autotuner.run, warmup, rep, timeout)
         artifact = autotuner.run()
         self._tuner_cache[key] = artifact.kernel
         return self._tuner_cache[key]
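The chained `set_profile_args(...).set_compile_args(...)` calls above rely on a fluent-builder pattern: each setter mutates the tuner and returns `self`. A minimal sketch of that pattern with an assumed `AutoTuner` shape (the real class takes many more parameters):

```python
class AutoTuner:
    """Assumed skeleton; only the chaining behavior is illustrated."""

    def __init__(self, fn, configs):
        self.fn, self.configs = fn, configs
        self.profile_args, self.compile_args = {}, {}

    def set_profile_args(self, **kwargs):
        self.profile_args.update(kwargs)
        return self  # returning self enables method chaining

    def set_compile_args(self, **kwargs):
        self.compile_args.update(kwargs)
        return self


tuner = AutoTuner(fn=lambda: None, configs=[]).set_profile_args(
    supply_type="auto").set_compile_args(
        out_idx=None, target="auto")
assert tuner.compile_args['target'] == "auto"
```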
diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py
@@ -47,7 +47,7 @@ class CompileArgs:
         "tl.disable_safe_memory_legalize": bool, default: False
     """
-    out_idx: Union[List[int], int] = -1
+    out_idx: Optional[Union[List[int], int]] = None
     execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
     target_host: Union[str, Target] = None
@@ -65,8 +65,6 @@ class CompileArgs:
     def __hash__(self):
         data = {
-            "out_idx":
-                self.out_idx,
             "execution_backend":
                 self.execution_backend,
             "target":
@@ -206,7 +204,7 @@ class AutotuneResult:
         cache_path: Path,
         target: Union[str, Target] = "auto",
         target_host: Union[str, Target] = None,
-        out_idx: List[int] = None,
+        out_idx: Optional[Union[List[int], int]] = None,
         execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
         pass_configs: dict = None,
         func: Callable = None,
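The widened `Optional[Union[List[int], int]]` annotation accepts `None`, a single index, or a list of indices. A hypothetical helper (not part of tilelang) showing how the three shapes can be normalized:

```python
from typing import List, Optional, Union


def normalize_out_idx(
        out_idx: Optional[Union[List[int], int]]) -> Optional[List[int]]:
    if out_idx is None:
        return None  # defer to the compile-side default
    if isinstance(out_idx, int):
        return [out_idx]  # a single output index
    return list(out_idx)  # already a list of indices


assert normalize_out_idx(None) is None
assert normalize_out_idx(-1) == [-1]
assert normalize_out_idx([0, 2]) == [0, 2]
```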
diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py
@@ -80,7 +80,7 @@ def compile(
 class _JitImplementation:
-    out_idx: Any
+    out_idx: Optional[Union[List[int], int]]
     target: Union[str, Target]
     target_host: Union[str, Target]
     execution_backend: Literal["dlpack", "ctypes", "cython"]
@@ -166,6 +166,18 @@ class _JitImplementation:
         def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
             # Separate out the tuning parameters from the user's kwargs
             tune_params = kwargs.pop('__tune_params', {})
+            # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
+            return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
+            if return_compile_arguments:
+                compile_args = {
+                    'out_idx': self.out_idx,
+                    'execution_backend': self.execution_backend,
+                    'target': self.target,
+                    'target_host': self.target_host,
+                    'verbose': self.verbose,
+                    'pass_configs': self.pass_configs,
+                }
+                return compile_args
 
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
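The `key_args_tuple` / `key_kwargs_tuple` lines hint at how the per-kernel cache key is assembled. A sketch of that construction, assuming every argument value is hashable:

```python
def make_cache_key(args, kwargs):
    # Sorting kwargs makes the key independent of keyword order.
    return (tuple(args), tuple(sorted(kwargs.items())))


cache = {}
cache[make_cache_key((256, 256), {"dtype": "float16"})] = "compiled-kernel"
assert make_cache_key((256, 256), {"dtype": "float16"}) in cache
```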