"driver/conv_driver.cpp" did not exist on "9f46cdf5faebb000ba4c3da33fa8c0bd05fc614d"
Unverified Commit e805f8e5 authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub
Browse files

[BugFix] Adding extra parameters into autotune hashkey (#1274)

* [BugFix] Adding extra parameters into autotune hashkey

* lint

* None check

* check serializable
parent b1922518
...@@ -235,7 +235,8 @@ class AutoTuner: ...@@ -235,7 +235,8 @@ class AutoTuner:
self._kernel_parameters = k_parameters self._kernel_parameters = k_parameters
self._function_parameters = f_parameters self._function_parameters = f_parameters
def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None: def generate_cache_key(self, parameters: dict[str, Any],
extra_parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process. """Generate a cache key for the auto-tuning process.
""" """
...@@ -261,6 +262,7 @@ class AutoTuner: ...@@ -261,6 +262,7 @@ class AutoTuner:
key_data = { key_data = {
"version": __version__, "version": __version__,
"op_parameters": tuple(op_parameters), "op_parameters": tuple(op_parameters),
"extra_parameters": extra_parameters,
"func_source": func_source, "func_source": func_source,
"configs": self.configs, "configs": self.configs,
"compile_args": hash(self.compile_args), "compile_args": hash(self.compile_args),
...@@ -293,10 +295,28 @@ class AutoTuner: ...@@ -293,10 +295,28 @@ class AutoTuner:
sig = inspect.signature(self.fn) sig = inspect.signature(self.fn)
parameters = sig.parameters parameters = sig.parameters
# NOTE(chaofan): We need to extract some parameters from the closure.
# Consider the case:
# def gemm(M, N, K):
# def kernel(...)
# If we only extract source, M/N/K will be symbolic and there will be cache problem.
extra_parameters: dict[str, Any] = {}
cells = self.fn.__closure__
var_names = self.fn.__code__.co_freevars
if cells is not None:
assert len(var_names) == len(cells), "Number of free variables does not match"
for var_name, cell in zip(var_names, cells):
if var_name in parameters:
continue
# Cell content must be serializable
assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \
f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}"
extra_parameters[var_name] = cell.cell_contents
if isinstance(self.configs, Callable): if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters) self.configs = self.configs(*self._kernel_parameters)
key = self.generate_cache_key(parameters) key = self.generate_cache_key(parameters, extra_parameters)
with self._lock: with self._lock:
if env.is_cache_enabled(): if env.is_cache_enabled():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment