Commit 18be9e07 authored by Chaofan Lin, committed by GitHub

[Carver] Multi-Threads Compilation for Fast Auto Tuning (#156)

* [Carver] Multi-Threads Compilation for Fast Auto Tuning

* Add progress bar for compilation

* lint
parent 782ca9f6
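The commit restructures the tuning loop into two phases: every candidate config is first compiled in parallel on a thread pool sized to roughly 90% of the available cores (with a tqdm progress bar), and only then is each precompiled kernel benchmarked one at a time under the existing timeout guard. Below is a minimal, self-contained sketch of that pattern; compile_config and bench are hypothetical stand-ins for the real jit_compile and target_fn.

import concurrent.futures
import os
import time
from functools import partial

from tqdm import tqdm

def compile_config(args, opt_flag=True):
    # Hypothetical stand-in for Autotuner.jit_compile: build the kernel
    # for one candidate config and return the compiled artifact.
    time.sleep(0.1)  # stands in for the compilation cost the pool amortizes
    return args

def bench(artifact):
    # Hypothetical stand-in for target_fn: measure the latency of a
    # precompiled kernel; no compilation happens on this path anymore.
    return float(sum(artifact))  # fake latency

config_args = [(64,), (128,), (256,)]  # one args tuple per candidate config

# Phase 1: parallel compilation on ~90% of the cores.
worker = partial(compile_config, opt_flag=True)  # bind kwargs; map() only feeds args
num_workers = max(1, int(os.cpu_count() * 0.9))
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
    artifacts = list(tqdm(pool.map(worker, config_args),
                          total=len(config_args), desc="Compiling configurations"))

# Phase 2: benchmark the precompiled kernels serially.
best_latency = min(bench(a) for a in artifacts)

Threads rather than processes are a reasonable choice here only because real kernel compilation spends most of its time in native code that releases the GIL; pure-Python work would not overlap this way.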
@@ -167,7 +167,7 @@ def matmul(M, N, K, with_roller):
         "enable_rasteration",
     ],
     warmup=3,
-    rep=5,
+    rep=20,
 )
 @jit(
     out_idx=[2],
@@ -11,6 +11,8 @@ from tqdm import tqdm
 import logging
 from dataclasses import dataclass
 import concurrent.futures
+import os
+from functools import partial
 
 logging.basicConfig(
     filename='out.log',
@@ -56,6 +58,10 @@ class Autotuner:
         self.jit_input_tensors = None
         self.ref_input_tensors = None
 
+    def jit_compile(self, args: Any, **kwds: Any) -> JITContext:
+        jit_context = self.fn(*args, **kwds)
+        return jit_context
+
     def run(self, *args: Any, **kwds: Any) -> Any:
         sig = inspect.signature(self.fn)
         bound_args = sig.bind(*args, **kwds)
@@ -64,9 +70,7 @@ class Autotuner:
         best_latency = 1e8
         best_config = None
 
-        def target_fn(*new_args, **kwds):
-            jit_context = self.fn(*new_args, **kwds)
-
+        def target_fn(jit_context):
             # Unpack the context
             mod = jit_context.mod
             profiler = jit_context.profiler
@@ -102,8 +106,11 @@ class Autotuner:
             return latency, self.ref_latency_cache
 
-        progress_bar = tqdm(self.configs, desc="Running configurations")
-        for config in progress_bar:
+        # Parallel compilation
+        config_args = []
+        jit_contexts = []
+        for config in self.configs:
             new_args = []
             for name, value in bound_args.arguments.items():
                 if name not in self.keys:
@@ -111,11 +118,33 @@ class Autotuner:
                 else:
                     new_args.append(config[name])
             new_args = tuple(new_args)
-            ref_latency = None
+            config_args.append(new_args)
+
+        worker = partial(
+            self.jit_compile,
+            **kwds,
+        )
+        # 90% utilization
+        num_workers = max(1, int(os.cpu_count() * 0.9))
+        pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
+        results = tqdm(
+            pool.map(
+                worker,
+                config_args,
+            ), desc="Compiling configurations")
+        for result in results:
+            jit_contexts.append(result)
+
+        ref_latency = None
+        progress_bar = tqdm(range(len(config_args)), desc="Bench configurations")
+        for i in progress_bar:
+            jit_context = jit_contexts[i]
+            config = config_args[i]
             try:
                 # Use ThreadPoolExecutor to enforce timeout on target_fn execution
                 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                    future = executor.submit(target_fn, *new_args, **kwds)
+                    future = executor.submit(target_fn, jit_context)
                     latency, ref_latency = future.result(timeout=self.timeout)
             except concurrent.futures.TimeoutError:
                 logging.error(f"Timeout exceeded for config {config}. Skipping this configuration.")
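Each benchmark run is still guarded by a timeout: the call is submitted to a single-worker executor and future.result(timeout=...) gives up if it takes too long. A self-contained sketch of that guard, with a hypothetical slow_bench standing in for target_fn:

import concurrent.futures
import logging
import time

def slow_bench(delay):
    # Hypothetical stand-in for target_fn: a benchmark that may hang.
    time.sleep(delay)
    return delay * 1000.0  # fake latency in ms

for config in (0.01, 5.0):
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(slow_bench, config)
            latency = future.result(timeout=1.0)  # abandon the result after 1 s
        print(f"config {config}: {latency} ms")
    except concurrent.futures.TimeoutError:
        logging.error(f"Timeout exceeded for config {config}. Skipping this configuration.")

Note that the timeout only bounds how long the caller waits for the result; Python threads cannot be killed, so a truly hung worker still runs to completion when the executor shuts down.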