Commit ae386a7b authored by Lei Wang, committed by LeiWang1999

[Bugfix] Add `__tune_params` into key hash for autotuning (#565)

* [Enhancement] Update AutoTuner and Profiler for improved kernel handling and output validation

- Modified AutoTuner to store its cache in a dedicated "autotuner" directory.
- Enhanced the kernel source saving logic in AutotuneResult, KernelCache, and AutoTunerCache to check for None before writing.
- Updated Profiler to handle None outputs gracefully during tensor comparisons, improving robustness in output validation.

* lint fix
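
The core fix (the `__tune_params` change in the title) is to fold the tuned parameters into the JIT memoization key, so that calls differing only in their tuning parameters no longer collide on a single cached kernel. Below is a minimal, self-contained sketch of the idea using hypothetical helper names, not TileLang's actual API:

```python
# Minimal sketch (hypothetical names) of why tuned parameters must be part of
# the memoization key: two calls that differ only in their tuning parameters
# would otherwise hit the same cache entry.

_kernel_cache = {}


def get_or_compile(compile_fn, args, kwargs, tune_params):
    # Fold positional args, keyword args, and tuned parameters into one
    # hashable key; sorting the dict items makes the key order-independent.
    key = (
        tuple(args),
        tuple(sorted(kwargs.items())),
        tuple(sorted(tune_params.items())),
    )
    if key not in _kernel_cache:
        _kernel_cache[key] = compile_fn(*args, **kwargs, **tune_params)
    return _kernel_cache[key]


# Without tune_params in the key, these two calls would return the same
# cached result even though block_M differs.
a = get_or_compile(lambda **kw: kw, (), {}, {"block_M": 128})
b = get_or_compile(lambda **kw: kw, (), {}, {"block_M": 64})
assert a != b
```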
parent 59172ff6
@@ -101,7 +101,7 @@ class AutoTuner:
     _kernel_parameters: Optional[Tuple[str, ...]] = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
-    cache_dir: Path = Path(TILELANG_CACHE_DIR)
+    cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"

     def __init__(self, fn: Callable, configs):
         self.fn = fn
@@ -350,6 +350,7 @@ class AutoTuner:
                     max_mismatched_ratio=max_mismatched_ratio)
                 latency = profiler.do_bench(
                     warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
+
                 if self.ref_latency_cache is None and ref_prog is not None:
                     self.ref_input_tensors = ref_input_tensors_supply()
                     self.ref_latency_cache = profiler.do_bench(
@@ -422,8 +423,6 @@ class AutoTuner:
                 logger.debug(f"Error: {e}")
                 continue
-
-            logging.debug(f"Config {config} latency: {latency} at index {i}")
             if latency < best_latency:
                 best_latency = latency
                 best_config = config
...
@@ -170,6 +170,7 @@ class AutotuneResult:
         # Save kernel source code
         try:
             kernel_path = os.path.join(cache_path, KERNEL_PATH)
-            with open(kernel_path, "w") as f:
-                f.write(kernel.artifact.kernel_source)
+            if kernel.artifact.kernel_source is not None:
+                with open(kernel_path, "w") as f:
+                    f.write(kernel.artifact.kernel_source)
         except Exception as e:
...
@@ -249,6 +249,7 @@ class KernelCache:
         # Save kernel source code
         try:
             kernel_path = os.path.join(cache_path, KERNEL_PATH)
-            with open(kernel_path, "w") as f:
-                f.write(kernel.artifact.kernel_source)
+            if kernel.artifact.kernel_source is not None:
+                with open(kernel_path, "w") as f:
+                    f.write(kernel.artifact.kernel_source)
         except Exception as e:
...
@@ -245,6 +245,7 @@ class AutoTunerCache:
         # Save kernel source code
         try:
             kernel_path = os.path.join(cache_path, KERNEL_PATH)
-            with open(kernel_path, "w") as f:
-                f.write(kernel.artifact.kernel_source)
+            if kernel.artifact.kernel_source is not None:
+                with open(kernel_path, "w") as f:
+                    f.write(kernel.artifact.kernel_source)
         except Exception as e:
...
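
The three hunks above (AutotuneResult, KernelCache, AutoTunerCache) apply the same guard before persisting kernel source. A minimal sketch of the pattern follows; the helper name and the "kernel.cu" filename are illustrative, not the library's actual constants:

```python
import os


def save_kernel_source(cache_path: str, kernel_source, filename: str = "kernel.cu"):
    # Some backends legitimately produce no standalone kernel source; writing
    # None would raise a TypeError inside f.write(), so skip the file instead.
    if kernel_source is None:
        return
    with open(os.path.join(cache_path, filename), "w") as f:
        f.write(kernel_source)
```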
@@ -181,7 +181,8 @@ class _JitImplementation:
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
-            key = (key_args_tuple, key_kwargs_tuple)
+            tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
+            key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
             if key not in self._kernel_cache:
                 # Ensure 'func' (the original user function) is used correctly
...
@@ -98,11 +98,15 @@ class Profiler:
         if isinstance(lib_outs, torch.Tensor):
             lib_outs = [lib_outs]
+        elif isinstance(lib_outs, tuple):
+            lib_outs = list(lib_outs)
         elif lib_outs is None:
             lib_outs = []

         if isinstance(ref_outs, torch.Tensor):
             ref_outs = [ref_outs]
+        elif isinstance(ref_outs, tuple):
+            ref_outs = list(ref_outs)
         elif ref_outs is None:
             ref_outs = []
@@ -119,6 +123,9 @@ class Profiler:
             # percentage_not_close = (num_not_close / total_elements) * 100
             # print(f"{percentage_not_close:.2f}% of the elements are not close.")
             # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}")
-            torch_assert_close(
-                lhs,
-                rhs,
+            if lhs is not None and rhs is not None:
+                # in case of numsplit template, the ref output may be None
+                # which means the value is invalid, so we skip the comparison
+                torch_assert_close(
+                    lhs,
+                    rhs,
...
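
Taken together, the Profiler changes normalize tuple outputs to lists and skip the element comparison when either side is None. A simplified sketch of that behavior, using `torch.testing.assert_close` in place of the library's own `torch_assert_close` helper and with illustrative tolerances:

```python
import torch


def _as_list(outs):
    # Accept a single tensor, a tuple/list of tensors, or None and always
    # return a list so the comparison loop has one shape to handle.
    if isinstance(outs, torch.Tensor):
        return [outs]
    if isinstance(outs, tuple):
        return list(outs)
    if outs is None:
        return []
    return outs


def assert_outputs_close(lib_outs, ref_outs, rtol=1e-2, atol=1e-2):
    for lhs, rhs in zip(_as_list(lib_outs), _as_list(ref_outs)):
        # A template may return None for an output slot whose value is
        # invalid; skip the comparison rather than fail on it.
        if lhs is None or rhs is None:
            continue
        torch.testing.assert_close(lhs, rhs, rtol=rtol, atol=atol)
```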