"src/include/threadwise_direct_convolution.hip.hpp" did not exist on "39775d484c4d15a5b895edfc9d2323f05ab2d3d4"
Commit 39ae28e4 authored by Lei Wang, committed by LeiWang1999
Browse files

Revert "[Bugfix] Use AutoTune cache_input_tensors properly (#483)" (#488)

This reverts commit 22e6de184fa4b307640b108b779f3d46d132f96c.
parent a10882e0
......@@ -249,19 +249,6 @@ class AutoTuner:
if self.jit_compile is None:
self.jit_compile = _compile
# Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
def target_fn(jit_context: JITContext):
# Unpack the context
kernel = jit_context.kernel
......@@ -277,30 +264,57 @@ class AutoTuner:
profiler = kernel.get_profiler(tensor_supply_type=supply_type)
if cache_input_tensors and self.jit_input_tensors is not None:
jit_input_tensors = self.jit_input_tensors
else:
jit_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
# Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
if cache_input_tensors:
jit_input_tensors = jit_input_tensors_supply()
if self.jit_input_tensors is not None:
if not check_tensor_list_compatibility(self.jit_input_tensors,
jit_input_tensors):
logger.warning(
"Incompatible input tensor properties detected between cached tensors and "
"tensors regenerated for the current configuration trial. "
"This can happen if different tuning configurations require different input shapes/dtypes "
"and input tensor caching is enabled.\n"
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:\n"
" `cache_input_tensors=False`\n"
"within your `.set_compile_args(...)` call.\n")
self.jit_input_tensors = jit_input_tensors
self.jit_input_tensors = jit_input_tensors
else:
self.jit_input_tensors = jit_input_tensors_supply()
if (not skip_check) and (ref_prog is not None):
if manual_check_prog is not None:
profiler.manual_assert_close(
ref_prog,
input_tensors=jit_input_tensors,
input_tensors=self.jit_input_tensors,
manual_check_prog=manual_check_prog)
else:
profiler.assert_allclose(
ref_prog,
input_tensors=jit_input_tensors,
input_tensors=self.jit_input_tensors,
rtol=rtol,
atol=atol,
max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors)
latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None:
ref_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench(
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
......@@ -353,14 +367,6 @@ class AutoTuner:
continue
ref_latency = None
if results_with_configs[0][0].cache_input_tensors:
supply_prog = results_with_configs[0][0].supply_prog
supply_type = results_with_configs[0][0].supply_type
profiler = results_with_configs[0][0].kernel.get_profiler(
tensor_supply_type=supply_type)
jit_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
self.jit_input_tensors = jit_input_tensors_supply()
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar:
jit_context, config = results_with_configs[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment