Commit f4bb9f6c authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix autotuning params (#585)

* [Enhancement] Update AutoTuner and Profiler for improved kernel handling and output validation

- Modified AutoTuner to store cache in a dedicated "autotuner" directory.
- Enhanced kernel source code saving logic in AutotuneResult and AutoTunerCache to check for None before writing.
- Updated Profiler to handle None outputs gracefully during tensor comparisons, improving robustness in output validation (see the sketch below).
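A minimal sketch of that None-guard comparison pattern, assuming torch tensors; `compare_outputs` is a hypothetical helper for illustration, not the Profiler's actual method:

```python
import torch

def compare_outputs(actual, expected, rtol=1e-2, atol=1e-2):
    """Hypothetical helper mirroring the Profiler fix: skip output slots
    where either side is None (e.g. kernels that write in place and return nothing)."""
    for a, e in zip(actual, expected):
        if a is None or e is None:
            continue  # nothing to validate for this slot
        torch.testing.assert_close(a, e, rtol=rtol, atol=atol)
```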

* lint fix

* [Enhancement] Improve error handling and documentation in AutoTuner

- Added traceback logging for exceptions during configuration testing to enhance debugging (sketched after this list).
- Expanded the AutoTuner class docstring to include detailed descriptions of new parameters for input tensor generation and validation, improving clarity for users.
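A minimal sketch of the traceback-logging change; `try_config` and `run` are hypothetical stand-ins for the AutoTuner's config-testing loop:

```python
import logging
import traceback

logger = logging.getLogger("autotuner")

def try_config(config, run):
    """Hypothetical wrapper around a single tuning trial."""
    try:
        return run(config)
    except Exception:
        logger.info(f"An error occurred while testing config {config}")
        # Log the full stack trace instead of just str(e), as this commit does.
        logger.debug(f"Error: {traceback.format_exc()}")
        return None
```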
parent 6cede73d
@@ -22,6 +22,7 @@ import signal
 import json
 import hashlib
 import threading
+import traceback
 from pathlib import Path
 from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
@@ -103,7 +104,7 @@ class AutoTuner:
     _kernel_parameters: Optional[Tuple[str, ...]] = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
-    cache_dir: Path = Path(TILELANG_CACHE_DIR)
+    cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"

     def __init__(self, fn: Callable, configs):
         self.fn = fn
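In isolation, the effect of that one-line `cache_dir` change looks like the following sketch (the cache-root value here is an assumption for illustration, not tilelang's actual default):

```python
from pathlib import Path

TILELANG_CACHE_DIR = Path.home() / ".tilelang" / "cache"  # assumed root, for illustration

# A dedicated subdirectory keeps autotuner artifacts from mixing with
# other cache entries stored under the same root.
autotuner_cache_dir = Path(TILELANG_CACHE_DIR) / "autotuner"
autotuner_cache_dir.mkdir(parents=True, exist_ok=True)
print(autotuner_cache_dir)  # e.g. /home/user/.tilelang/cache/autotuner
```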
@@ -352,6 +353,7 @@ class AutoTuner:
                     max_mismatched_ratio=max_mismatched_ratio)
                 latency = profiler.do_bench(
                     warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
                 if self.ref_latency_cache is None and ref_prog is not None:
                     self.ref_input_tensors = ref_input_tensors_supply()
                     self.ref_latency_cache = profiler.do_bench(
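The `ref_latency_cache` guard above benchmarks the reference program at most once across all configurations. A self-contained sketch of that memoization pattern (the class and its timing loop are illustrative, not the real Profiler):

```python
import time
from typing import Callable, Optional

class RefLatencyCache:
    """Illustrative: benchmark a reference program once, then reuse the result."""

    def __init__(self, ref_prog: Optional[Callable] = None):
        self.ref_prog = ref_prog
        self.ref_latency_cache: Optional[float] = None

    def ref_latency(self, *inputs) -> Optional[float]:
        if self.ref_prog is None:
            return None  # no reference program, nothing to benchmark
        if self.ref_latency_cache is None:  # first call only
            start = time.perf_counter()
            self.ref_prog(*inputs)
            self.ref_latency_cache = (time.perf_counter() - start) * 1e3  # ms
        return self.ref_latency_cache
```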
@@ -417,15 +419,13 @@
                         f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
                     )
                     continue
-                except Exception as e:
+                except Exception:
                     logger.info(
                         f"An error occurred while testing config {config}, checkout autotuner.log for more details"
                     )
-                    logger.debug(f"Error: {e}")
+                    logger.debug(f"Error: {traceback.format_exc()}")
                     continue
                 logging.debug(f"Config {config} latency: {latency} at index {i}")
                 if latency < best_latency:
                     best_latency = latency
                     best_config = config
@@ -515,13 +515,37 @@ class _AutoTunerImplementation:
             warmup: Number of warmup iterations before timing.
             rep: Number of repetitions for timing measurements.
             timeout: Maximum time (in seconds) allowed for each configuration.
+            supply_type: Strategy for generating input tensors (random/zeros/etc.).
+            ref_prog: Reference implementation for validation.
+            supply_prog: Custom function to provide input tensors.
+            rtol: Relative tolerance for numerical validation.
+            atol: Absolute tolerance for numerical validation.
+            max_mismatched_ratio: Allowed percentage of mismatched values.
+            skip_check: Bypass validation against the reference implementation.
+            manual_check_prog: Custom validation function.
+            cache_input_tensors: Reuse input tensors across trials.
         """
-        self.configs = configs
-        self.warmup = warmup
-        self.rep = rep
-        self.timeout = timeout
-        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}
+        # Configuration and benchmarking parameters
+        self.configs = configs  # Search space of tuning configurations
+        self.warmup = warmup  # Warmup iterations for stable measurements
+        self.rep = rep  # Measurement repetitions for statistics
+        self.timeout = timeout  # Per-configuration timeout threshold
+        # Tensor handling and validation setup
+        self.supply_type = supply_type  # Input tensor generation strategy
+        self.ref_prog = ref_prog  # Ground truth implementation
+        self.supply_prog = supply_prog  # Custom input data provider
+        self.rtol = rtol  # Relative error tolerance
+        self.atol = atol  # Absolute error tolerance
+        self.max_mismatched_ratio = max_mismatched_ratio  # Allowed mismatch
+        # Validation control flags
+        self.skip_check = skip_check  # Bypass accuracy verification
+        self.manual_check_prog = manual_check_prog  # Custom validation
+        self.cache_input_tensors = cache_input_tensors  # Reuse inputs
+        # Cache for storing tuned kernel implementations
+        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel
         # This tells the type checker what the *wrapper* function will return.
         # this is for linting, please do not remove it.
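A self-contained sketch of the `(args, kwargs) -> compiled kernel` caching pattern that `_tuner_cache` implements; `TunerCache` is an illustrative stand-in, not the class in this diff:

```python
from typing import Any, Callable, Dict, Tuple

class TunerCache:
    """Illustrative (args, kwargs) -> result cache."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[Any, ...], Any] = {}

    def get_or_tune(self, tune: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Sort kwargs into a tuple so the key is hashable and order-insensitive.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in self._cache:
            self._cache[key] = tune(*args, **kwargs)  # tune once, reuse thereafter
        return self._cache[key]
```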
......