Commit 18be9e07 authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub
Browse files

[Carver] Multi-Threads Compilation for Fast Auto Tuning (#156)

* [Carver] Multi-Threads Compilation for Fast Auto Tuning

* Add progress bar for compilation

* lint
parent 782ca9f6
......@@ -167,7 +167,7 @@ def matmul(M, N, K, with_roller):
"enable_rasteration",
],
warmup=3,
rep=5,
rep=20,
)
@jit(
out_idx=[2],
......
......@@ -11,6 +11,8 @@ from tqdm import tqdm
import logging
from dataclasses import dataclass
import concurrent.futures
import os
from functools import partial
logging.basicConfig(
filename='out.log',
......@@ -56,6 +58,10 @@ class Autotuner:
self.jit_input_tensors = None
self.ref_input_tensors = None
def jit_compile(self, args: Any, **kwds: Any) -> JITContext:
jit_context = self.fn(*args, **kwds)
return jit_context
def run(self, *args: Any, **kwds: Any) -> Any:
sig = inspect.signature(self.fn)
bound_args = sig.bind(*args, **kwds)
......@@ -64,9 +70,7 @@ class Autotuner:
best_latency = 1e8
best_config = None
def target_fn(*new_args, **kwds):
jit_context = self.fn(*new_args, **kwds)
def target_fn(jit_context):
# Unpack the context
mod = jit_context.mod
profiler = jit_context.profiler
......@@ -102,8 +106,11 @@ class Autotuner:
return latency, self.ref_latency_cache
progress_bar = tqdm(self.configs, desc="Running configurations")
for config in progress_bar:
# Parallel compilation
config_args = []
jit_contexts = []
for config in self.configs:
new_args = []
for name, value in bound_args.arguments.items():
if name not in self.keys:
......@@ -111,11 +118,33 @@ class Autotuner:
else:
new_args.append(config[name])
new_args = tuple(new_args)
ref_latency = None
config_args.append(new_args)
worker = partial(
self.jit_compile,
**kwds,
)
# 90% utilization
num_workers = max(1, int(os.cpu_count() * 0.9))
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
results = tqdm(
pool.map(
worker,
config_args,
), desc="Compiling configurations")
for result in results:
jit_contexts.append(result)
ref_latency = None
progress_bar = tqdm(range(len(config_args)), desc="Bench configurations")
for i in progress_bar:
jit_context = jit_contexts[i]
config = config_args[i]
try:
# Use ThreadPoolExecutor to enforce timeout on target_fn execution
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(target_fn, *new_args, **kwds)
future = executor.submit(target_fn, jit_context)
latency, ref_latency = future.result(timeout=self.timeout)
except concurrent.futures.TimeoutError:
logging.error(f"Timeout exceeded for config {config}. Skipping this configuration.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment