"testing/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "8edd6941414e112f3fceb56f0dfcfc65c993fc85"
Commit e6f77253 authored by Lei Wang, committed by GitHub

[AutoTune] Enable config-performance trace (#174)

* Improve Autotuner and CUDA Compatibility for Tensor Core Policies

- Enhance autotuner with robust parallel compilation and error handling (see the sketch after this list)
- Add logging for better debugging during configuration compilation
- Support SM90 compute capabilities in TensorCore and matmul analysis policies
- Improve future handling and result tracking in autotuner
- Add more flexible SM version checks for pipeline and async copy stages
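
For orientation, here is a minimal, self-contained sketch of the submit/as_completed pattern this change adopts. The `compile_config` worker and the `configs` list are hypothetical stand-ins for illustration, not the tilelang API:

```python
import concurrent.futures
import logging
import os

from tqdm import tqdm

logger = logging.getLogger(__name__)

def compile_config(config):
    # Hypothetical stand-in for the real compilation worker.
    if config.get("bad"):
        raise RuntimeError("compilation failed")
    return f"kernel<block={config['block']}>"

configs = [{"block": 64}, {"block": 128, "bad": True}, {"block": 256}]

num_workers = max(1, int(os.cpu_count() * 0.9))  # ~90% CPU utilization, as in the commit
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
    # Track which future corresponds to which config so failures can be attributed.
    future_to_index = {pool.submit(compile_config, c): i for i, c in enumerate(configs)}
    results_with_configs = []
    for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(future_to_index),
            desc="Compiling configurations"):
        idx = future_to_index[future]
        try:
            results_with_configs.append((future.result(), configs[idx]))
        except Exception:
            # A failed config is logged and skipped instead of aborting the sweep.
            logger.debug("Compilation failed for config %s at index %d", configs[idx], idx)
```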

* Refactor Autotuner Parallel Compilation with Improved Error Handling

- Enhance tqdm progress bar formatting for concurrent configuration compilation (see the sketch after this list)
- Simplify exception handling in parallel compilation process
- Remove unnecessary logging and improve code readability
- Optimize thread pool shutdown and result processing
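
A small sketch of the progress-reporting conventions the refactor settles on: `set_postfix` keeps the best latency visible on the bar, while `tqdm.write` logs per-config results without breaking the bar. The contexts and latency numbers here are dummy values for illustration only:

```python
import math

from tqdm import tqdm

# Dummy (jit_context, config) pairs and latencies, for illustration only.
results_with_configs = [("ctx-a", {"block": 64}), ("ctx-b", {"block": 128})]
fake_latencies = {"ctx-a": 1.8, "ctx-b": 1.2}

best_latency, best_config = math.inf, None
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar:
    jit_context, config = results_with_configs[i]
    latency = fake_latencies[jit_context]
    if latency < best_latency:
        best_latency, best_config = latency, config
    # Update the postfix after the comparison so it always reflects the current best.
    progress_bar.set_postfix({"best_latency": best_latency})
    tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
```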
parent 8344af52
@@ -14,6 +14,8 @@ import concurrent.futures
 import os
 from functools import partial
 
+logger = logging.getLogger(__name__)
+
 logging.basicConfig(
     filename='out.log',
     filemode='w',
@@ -108,7 +110,6 @@ class Autotuner:
         # Parallel compilation
         config_args = []
-        jit_contexts = []
         for config in self.configs:
             new_args = []
@@ -128,39 +129,56 @@ class Autotuner:
         # 90% utilization
         num_workers = max(1, int(os.cpu_count() * 0.9))
         pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
-        results = tqdm(
-            pool.map(
-                worker,
-                config_args,
-            ), desc="Compiling configurations")
-        for result in results:
-            jit_contexts.append(result)
+
+        # Submit all compilation jobs
+        futures = []
+        future_to_index = {}  # Track which future corresponds to which config
+        for i, config_arg in enumerate(config_args):
+            future = pool.submit(worker, config_arg)
+            futures.append(future)
+            future_to_index[future] = i
+
+        # Process results with error handling
+        results_with_configs = []
+        for future in tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(futures),
+                desc="Compiling configurations"):
+            idx = future_to_index[future]
+            config = config_args[idx]
+            try:
+                result = future.result()
+                results_with_configs.append((result, config))
+            except Exception:
+                logger.debug(f"Compilation failed for config {config} at index {idx}")
+                continue
 
         ref_latency = None
-        progress_bar = tqdm(range(len(config_args)), desc="Bench configurations")
+        progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
         for i in progress_bar:
-            jit_context = jit_contexts[i]
-            config = config_args[i]
+            jit_context, config = results_with_configs[i]
             try:
                 # Use ThreadPoolExecutor to enforce timeout on target_fn execution
                 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                     future = executor.submit(target_fn, jit_context)
                     latency, ref_latency = future.result(timeout=self.timeout)
             except concurrent.futures.TimeoutError:
-                logging.error(f"Timeout exceeded for config {config}. Skipping this configuration.")
+                logger.debug(f"Timeout exceeded for config {config}. Skipping this configuration.")
                 continue
             except Exception as e:
-                logging.error(f"An error occurred while testing config {config}: {e}")
+                logger.debug(f"An error occurred while testing config {config}: {e}")
                 continue
 
-            logging.info(f"Config {config} latency: {latency}")
-            progress_bar.set_postfix({"best_latency": best_latency})
+            logging.debug(f"Config {config} latency: {latency} at index {i}")
             if latency < best_latency:
                 best_latency = latency
                 best_config = config
-            tqdm.write(f"Tuned Latency {latency} with config {config}")
+            progress_bar.set_postfix({"best_latency": best_latency})
+            tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")
+
+        pool.shutdown()
 
         return best_latency, best_config, ref_latency
 
     def __call__(self, *args: Any, **kwds: Any) -> Any:
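
The timeout guard in the benchmarking loop relies on `future.result(timeout=...)` rather than on killing the worker. A standalone sketch of that pattern, with `bench_fn` as a hypothetical stand-in for `target_fn`:

```python
import concurrent.futures
import time

def bench_fn(jit_context):
    # Hypothetical stand-in: pretend some kernels hang while benchmarking.
    if jit_context == "hangs":
        time.sleep(3)
    return 1.23, 1.50  # (latency, ref_latency)

for jit_context in ["fast", "hangs"]:
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(bench_fn, jit_context)
        try:
            latency, ref_latency = future.result(timeout=1.0)
        except concurrent.futures.TimeoutError:
            # Only the wait is abandoned; leaving the with-block still joins the
            # runaway thread, since __exit__ calls shutdown(wait=True).
            print(f"Timeout exceeded for {jit_context}, skipping")
            continue
    print(f"{jit_context}: latency={latency}, ref={ref_latency}")
```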
@@ -575,14 +575,14 @@ def get_tensorized_func_and_tags(
     # analysis pipeline stage
     # todo(lei): maybe we can integrate this into policy in the future
     tags["pipeline_stage"] = 1
-    if target.kind.name == "cuda" and check_sm_version(target.arch) == 80:
+    if target.kind.name == "cuda" and check_sm_version(target.arch) in {80, 90}:
         # enable pipeline stage only for sm_80 devices
         tags["pipeline_stage"] = 2
     # analysis async copy
     # todo(lei): maybe we can integrate this into policy in the future
     tags["use_async_copy"] = False
-    if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) >= 80:
+    if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) in {80, 90}:
         # async copy only works in software pipeline.
         tags["use_async_copy"] = True
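
Both checks funnel through `check_sm_version`, which reduces a target arch string such as "sm_90a" to an integer. Its implementation is not part of this diff; a plausible minimal sketch:

```python
import re

def check_sm_version(arch: str) -> int:
    # Hypothetical sketch: extract the numeric compute capability from
    # strings such as "sm_80", "sm_90", or "sm_90a"; return -1 if absent.
    match = re.search(r"sm_(\d+)", arch)
    return int(match.group(1)) if match else -1

assert check_sm_version("sm_90a") == 90
assert check_sm_version("sm_80") == 80
```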
@@ -33,7 +33,7 @@ class TensorCorePolicy(DefaultPolicy):
         if pipleline_stage:
             self.pipeline_stage = pipleline_stage
         else:
-            if self.arch.compute_capability == "sm_80":
+            if self.arch.compute_capability in {"sm_80", "sm_90", "sm_90a"}:
                 self.pipeline_stage = 2
             else:
                 self.pipeline_stage = 1
@@ -41,7 +41,7 @@ class TensorCorePolicy(DefaultPolicy):
         if use_async_copy:
             self.use_async_copy = use_async_copy
         else:
-            if self.arch.compute_capability == "sm_80":
+            if self.arch.compute_capability in {"sm_80", "sm_90", "sm_90a"}:
                 self.use_async_copy = True
             else:
                 self.use_async_copy = False