"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a284f71d11f6f6e52fe63439b9eeef286d4e1e47"
Commit a10882e0 authored by yyttt6's avatar yyttt6 Committed by LeiWang1999
Browse files

[Bugfix] Use AutoTune cache_input_tensors properly (#483)

parent fa0fca58
...@@ -249,6 +249,19 @@ class AutoTuner: ...@@ -249,6 +249,19 @@ class AutoTuner:
if self.jit_compile is None: if self.jit_compile is None:
self.jit_compile = _compile self.jit_compile = _compile
# Factory functions for generating input tensors.
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
def target_fn(jit_context: JITContext): def target_fn(jit_context: JITContext):
# Unpack the context # Unpack the context
kernel = jit_context.kernel kernel = jit_context.kernel
...@@ -264,57 +277,30 @@ class AutoTuner: ...@@ -264,57 +277,30 @@ class AutoTuner:
profiler = kernel.get_profiler(tensor_supply_type=supply_type) profiler = kernel.get_profiler(tensor_supply_type=supply_type)
# Factory functions for generating input tensors. if cache_input_tensors and self.jit_input_tensors is not None:
# This encapsulates the logic of using either a custom supply program (`supply_prog`) jit_input_tensors = self.jit_input_tensors
# or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
def func():
if supply_prog is not None:
return supply_prog(profiler._get_params(with_output=with_output))
else:
return profiler._get_inputs(with_output=with_output)
return func
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
if cache_input_tensors:
jit_input_tensors = jit_input_tensors_supply()
if self.jit_input_tensors is not None:
if not check_tensor_list_compatibility(self.jit_input_tensors,
jit_input_tensors):
logger.warning(
"Incompatible input tensor properties detected between cached tensors and "
"tensors regenerated for the current configuration trial. "
"This can happen if different tuning configurations require different input shapes/dtypes "
"and input tensor caching is enabled.\n"
"To ensure fresh, compatible inputs are generated for every trial "
"you can disable caching by setting:\n"
" `cache_input_tensors=False`\n"
"within your `.set_compile_args(...)` call.\n")
self.jit_input_tensors = jit_input_tensors
self.jit_input_tensors = jit_input_tensors
else: else:
self.jit_input_tensors = jit_input_tensors_supply() jit_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
jit_input_tensors = jit_input_tensors_supply()
if (not skip_check) and (ref_prog is not None): if (not skip_check) and (ref_prog is not None):
if manual_check_prog is not None: if manual_check_prog is not None:
profiler.manual_assert_close( profiler.manual_assert_close(
ref_prog, ref_prog,
input_tensors=self.jit_input_tensors, input_tensors=jit_input_tensors,
manual_check_prog=manual_check_prog) manual_check_prog=manual_check_prog)
else: else:
profiler.assert_allclose( profiler.assert_allclose(
ref_prog, ref_prog,
input_tensors=self.jit_input_tensors, input_tensors=jit_input_tensors,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
max_mismatched_ratio=max_mismatched_ratio) max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench( latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors)
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None: if self.ref_latency_cache is None and ref_prog is not None:
ref_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
self.ref_input_tensors = ref_input_tensors_supply() self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench( self.ref_latency_cache = profiler.do_bench(
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
...@@ -367,6 +353,14 @@ class AutoTuner: ...@@ -367,6 +353,14 @@ class AutoTuner:
continue continue
ref_latency = None ref_latency = None
if results_with_configs[0][0].cache_input_tensors:
supply_prog = results_with_configs[0][0].supply_prog
supply_type = results_with_configs[0][0].supply_type
profiler = results_with_configs[0][0].kernel.get_profiler(
tensor_supply_type=supply_type)
jit_input_tensors_supply = get_input_tensors_supply(
supply_prog, profiler, with_output=False)
self.jit_input_tensors = jit_input_tensors_supply()
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations") progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
for i in progress_bar: for i in progress_bar:
jit_context, config = results_with_configs[i] jit_context, config = results_with_configs[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment