Commit d627fd58 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix compilation issues for AMD CDNA element size check (#364)

* [Refactor] Update AutoTuner run method and timeout handling

- Modified the `run` method to reduce the default timeout from 100 to 30 seconds so that stalled configurations are abandoned sooner.
- Changed the `get_input_tensors_supply` call to skip output-tensor generation, reducing the cost of preparing input tensors.
- Refactored the latency measurement to simplify the benchmarking call and to enforce the timeout by running the benchmark through a `ThreadPoolExecutor` (a minimal sketch of this pattern follows the commit description below).
- Added logging for timeout occurrences to aid debugging and performance analysis.

* bug fix

* lint fix
parent e3065f0b
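
For readers skimming the diff: the timeout handling described above follows the standard concurrent.futures pattern of submitting the benchmark to a single-worker ThreadPoolExecutor and bounding the wait on its future. Below is a minimal sketch of that pattern; target_fn, jit_context, and the logger name are stand-ins, not the repository's actual symbols (the real ones appear in the Python hunks further down).

import concurrent.futures
import logging

logger = logging.getLogger("autotuner")


def bench_with_timeout(target_fn, jit_context, timeout=30):
    # Run the benchmark in a dedicated worker thread so its wall-clock time can be bounded.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(target_fn, jit_context)
    try:
        # result() raises concurrent.futures.TimeoutError if the benchmark
        # does not finish within `timeout` seconds.
        return future.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        logger.info("Benchmark timed out after %s seconds", timeout)
        return None
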
@@ -170,7 +170,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
-                               const int warp_n, bool transposed) {
+                               const int warp_n, const int element_size,
+                               bool transposed) {
   // assume not transposed
   ICHECK(block_m % warp_m == 0);
   ICHECK(block_n % warp_n == 0);
@@ -146,7 +146,8 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
-                               const int warp_n, bool transposed = false);
+                               const int warp_n, const int element_size,
+                               bool transposed = false);
 // Default Memory Layout
 Layout makeGemmLayoutLinear(int stride, int continuous);
@@ -275,8 +275,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
                              A->dtype.bits(), kPack);
     results.Set(A, shared_layout);
   } else if (A.scope() == "local.fragment") {
-    results.Set(
-        A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
+    results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
+                                         A->dtype.bits(), trans_A));
   } else {
     ICHECK(0);
   }
@@ -172,7 +172,7 @@ class AutoTuner:
         self.jit_compile = _compile
         return self

-    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
+    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
         """Run the auto-tuning process.

         Args:
@@ -218,7 +218,7 @@ class AutoTuner:
             return func

-        jit_input_tensors_supply = get_input_tensors_supply(with_output=(profiler == "tvm"))
+        jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
         ref_input_tensors_supply = get_input_tensors_supply(with_output=False)

         if cache_input_tensors:
@@ -247,9 +247,8 @@ class AutoTuner:
                 rtol=rtol,
                 atol=atol,
                 max_mismatched_ratio=max_mismatched_ratio)
            latency = profiler.do_bench(
-                profiler.func, n_warmup=warmup, n_repeat=rep, input_tensors=self.jit_input_tensors)
+                warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
            if self.ref_latency_cache is None and ref_prog is not None:
                self.ref_input_tensors = ref_input_tensors_supply()
                self.ref_latency_cache = profiler.do_bench(
@@ -304,7 +303,15 @@ class AutoTuner:
            try:
                # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
                # Because tma init may behave strangely with one thread
-                latency, ref_latency = target_fn(jit_context)
+                # latency, ref_latency = target_fn(jit_context)
+                benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+                future = benchmark_executor.submit(target_fn, jit_context)
+                latency, ref_latency = future.result(timeout=timeout)
+            except concurrent.futures.TimeoutError:
+                logger.info(
+                    f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
+                )
+                continue
            except Exception as e:
                logger.info(
                    f"An error occurred while testing config {config}, checkout autotuner.log for more details"
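
A design note on the timeout hunk above: future.result(timeout=...) only bounds how long the tuner waits; the worker thread running target_fn is not interrupted and keeps running (or hanging) in the background, which is why the new code simply logs the timeout and continues with the next config. A self-contained illustration of that behavior is shown below; all names here are illustrative, not from the repository.

import concurrent.futures
import time


def slow_benchmark():
    # Stand-in for a kernel benchmark that overruns its time budget.
    time.sleep(5)
    return 1.23


executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
future = executor.submit(slow_benchmark)
try:
    latency = future.result(timeout=1)  # stop waiting after 1 second
except concurrent.futures.TimeoutError:
    print("timed out; the worker thread is still running in the background")
executor.shutdown(wait=False)  # return immediately instead of joining the abandoned worker
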