"vscode:/vscode.git/clone" did not exist on "57a8ccf3bafb87e40f62a88d927fcbd01de7eb4c"
Commit f4bb9f6c authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix autotuning params (#585)

* [Enhancement] Update AutoTuner and Profiler for improved kernel handling and output validation

- Modified AutoTuner to store cache in a dedicated "autotuner" directory.
- Enhanced kernel source code saving logic in AutotuneResult and AutoTunerCache to check for None before writing.
- Updated Profiler to handle None outputs gracefully during tensor comparisons, improving robustness in output validation (see the sketch below).
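
As an illustration of the None-handling described above, here is a minimal sketch of a None-tolerant output comparison; the helper name `compare_outputs` and its signature are hypothetical and are not the Profiler's actual API:

```python
import torch

def compare_outputs(lhs, rhs, rtol=1e-2, atol=1e-2, max_mismatched_ratio=0.01):
    """Hypothetical helper mirroring the None-tolerant comparison idea."""
    # Skip slots where either side is None (e.g. kernels that write outputs
    # in place and return nothing) instead of crashing on the comparison.
    if lhs is None or rhs is None:
        return
    # Allow a bounded fraction of mismatched elements, in the spirit of the
    # autotuner's max_mismatched_ratio parameter.
    mismatched = ~torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
    ratio = mismatched.float().mean().item()
    assert ratio <= max_mismatched_ratio, (
        f"{ratio:.2%} of elements mismatched (allowed {max_mismatched_ratio:.2%})")
```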

* lint fix

* [Enhancement] Improve error handling and documentation in AutoTuner

- Added traceback logging for exceptions during configuration testing to enhance debugging (see the sketch below).
- Expanded the AutoTuner class docstring to include detailed descriptions of new parameters for input tensor generation and validation, improving clarity for users.
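
The traceback-logging pattern described above, distilled into a self-contained sketch; `try_configs` and `run_one` are placeholder names rather than the AutoTuner's actual methods:

```python
import logging
import traceback

logger = logging.getLogger("autotuner")

def try_configs(configs, run_one):
    """Benchmark each config, skipping ones that raise, and keep the fastest."""
    best_latency, best_config = float("inf"), None
    for config in configs:
        try:
            latency = run_one(config)
        except Exception:
            # Short notice at INFO; the full stack trace goes to DEBUG so the
            # log stays readable while still capturing the failure details.
            logger.info(f"An error occurred while testing config {config}")
            logger.debug(f"Error: {traceback.format_exc()}")
            continue
        if latency < best_latency:
            best_latency, best_config = latency, config
    return best_config, best_latency
```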
parent 6cede73d
@@ -22,6 +22,7 @@ import signal
import json
import hashlib
import threading
+import traceback
from pathlib import Path
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
@@ -103,7 +104,7 @@ class AutoTuner:
    _kernel_parameters: Optional[Tuple[str, ...]] = None
    _lock = threading.Lock()  # For thread safety
    _memory_cache = {}  # In-memory cache dictionary
-    cache_dir: Path = Path(TILELANG_CACHE_DIR)
+    cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"

    def __init__(self, fn: Callable, configs):
        self.fn = fn
@@ -352,6 +353,7 @@ class AutoTuner:
                    max_mismatched_ratio=max_mismatched_ratio)
                latency = profiler.do_bench(
                    warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
+
                if self.ref_latency_cache is None and ref_prog is not None:
                    self.ref_input_tensors = ref_input_tensors_supply()
                    self.ref_latency_cache = profiler.do_bench(
@@ -417,15 +419,13 @@ class AutoTuner:
                    f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
                )
                continue
-            except Exception as e:
+            except Exception:
                logger.info(
                    f"An error occurred while testing config {config}, checkout autotuner.log for more details"
                )
-                logger.debug(f"Error: {e}")
+                logger.debug(f"Error: {traceback.format_exc()}")
                continue

-            logging.debug(f"Config {config} latency: {latency} at index {i}")
-
            if latency < best_latency:
                best_latency = latency
                best_config = config
@@ -515,13 +515,37 @@ class _AutoTunerImplementation:
            warmup: Number of warmup iterations before timing.
            rep: Number of repetitions for timing measurements.
            timeout: Maximum time (in seconds) allowed for each configuration.
+            supply_type: Strategy for generating input tensors (random/zeros/etc)
+            ref_prog: Reference implementation for validation
+            supply_prog: Custom function to provide input tensors
+            rtol: Relative tolerance for numerical validation
+            atol: Absolute tolerance for numerical validation
+            max_mismatched_ratio: Allowed percentage of mismatched values
+            skip_check: Bypass validation against reference implementation
+            manual_check_prog: Custom validation function
+            cache_input_tensors: Reuse input tensors across trials
        """
-        self.configs = configs
-        self.warmup = warmup
-        self.rep = rep
-        self.timeout = timeout
-        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}
+        # Configuration and benchmarking parameters
+        self.configs = configs  # Search space of tuning configurations
+        self.warmup = warmup  # Warmup iterations for stable measurements
+        self.rep = rep  # Measurement repetitions for statistics
+        self.timeout = timeout  # Per-configuration timeout threshold
+
+        # Tensor handling and validation setup
+        self.supply_type = supply_type  # Input tensor generation strategy
+        self.ref_prog = ref_prog  # Ground truth implementation
+        self.supply_prog = supply_prog  # Custom input data provider
+        self.rtol = rtol  # Relative error tolerance
+        self.atol = atol  # Absolute error tolerance
+        self.max_mismatched_ratio = max_mismatched_ratio  # Allowed mismatch
+
+        # Validation control flags
+        self.skip_check = skip_check  # Bypass accuracy verification
+        self.manual_check_prog = manual_check_prog  # Custom validation
+        self.cache_input_tensors = cache_input_tensors  # Reuse inputs
+
+        # Cache for storing tuned kernel implementations
+        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel

        # This tells the type checker what the *wrapper* function will return.
        # this is for linting, please do not remove it.
...
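
As a rough usage sketch of the parameters documented in the diff above, the snippet below shows how they might be passed to the autotune decorator; the import path, the decorator stacking, and the placeholder configs and kernel body are assumptions rather than code from this commit:

```python
import torch
from tilelang.autotuner import autotune  # import path assumed

def ref_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Reference implementation passed as ref_prog for output validation.
    return a @ b

@autotune(
    configs=[{"block_M": 64}, {"block_M": 128}],  # placeholder search space
    warmup=10, rep=50, timeout=60,                # benchmarking parameters
    ref_prog=ref_matmul,                          # ground-truth program
    rtol=1e-2, atol=1e-2,                         # numerical tolerances
    max_mismatched_ratio=0.01,                    # allowed mismatch fraction
    cache_input_tensors=True,                     # reuse inputs across trials
)
def tuned_matmul(block_M: int = 64):
    ...  # the actual TileLang kernel definition would go here
```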