Commit ae386a7b authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Add `__tune_params` into key hash for autotuning (#565)

* [Enhancement] Update AutoTuner and Profiler for improved kernel handling and output validation

- Modified AutoTuner to store cache in a dedicated "autotuner" directory.
- Enhanced kernel source code saving logic in AutotuneResult and AutoTunerCache to check for None before writing.
- Updated Profiler to handle None outputs gracefully during tensor comparisons, improving robustness in output validation.

* Lint fixes
parent 59172ff6
......@@ -101,7 +101,7 @@ class AutoTuner:
_kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR)
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
def __init__(self, fn: Callable, configs):
self.fn = fn
......@@ -350,6 +350,7 @@ class AutoTuner:
max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench(
......@@ -422,8 +423,6 @@ class AutoTuner:
logger.debug(f"Error: {e}")
continue
logging.debug(f"Config {config} latency: {latency} at index {i}")
if latency < best_latency:
best_latency = latency
best_config = config
......
......@@ -170,8 +170,9 @@ class AutotuneResult:
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
if kernel.artifact.kernel_source is not None:
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
logger.error(f"Error saving kernel source code to disk: {e}")
......
......@@ -249,8 +249,9 @@ class KernelCache:
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
if kernel.artifact.kernel_source is not None:
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
......
......@@ -245,8 +245,9 @@ class AutoTunerCache:
# Save kernel source code
try:
kernel_path = os.path.join(cache_path, KERNEL_PATH)
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
if kernel.artifact.kernel_source is not None:
with open(kernel_path, "w") as f:
f.write(kernel.artifact.kernel_source)
except Exception as e:
self.logger.error(f"Error saving kernel source code to disk: {e}")
......
......@@ -181,7 +181,8 @@ class _JitImplementation:
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
key = (key_args_tuple, key_kwargs_tuple)
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
if key not in self._kernel_cache:
# Ensure 'func' (the original user function) is used correctly
......
......@@ -98,11 +98,15 @@ class Profiler:
if isinstance(lib_outs, torch.Tensor):
lib_outs = [lib_outs]
elif isinstance(lib_outs, tuple):
lib_outs = list(lib_outs)
elif lib_outs is None:
lib_outs = []
if isinstance(ref_outs, torch.Tensor):
ref_outs = [ref_outs]
elif isinstance(ref_outs, tuple):
ref_outs = list(ref_outs)
elif ref_outs is None:
ref_outs = []
......@@ -119,15 +123,18 @@ class Profiler:
# percentage_not_close = (num_not_close / total_elements) * 100
# print(f"{percentage_not_close:.2f}% of the elements are not close.")
# print(f"Total elements: {total_elements}, Not close elements: {num_not_close}")
torch_assert_close(
lhs,
rhs,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
base_name="tilelang",
ref_name="ref",
)
if lhs is not None and rhs is not None:
# in case of numsplit template, the ref output may be None
# which means the value is invalid, so we skip the comparison
torch_assert_close(
lhs,
rhs,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio,
base_name="tilelang",
ref_name="ref",
)
def manual_assert_close(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment