OpenDAS / tilelang
"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a284f71d11f6f6e52fe63439b9eeef286d4e1e47"
Commit a10882e0, authored May 12, 2025 by yyttt6; committed by LeiWang1999, May 12, 2025

[Bugfix] Use AutoTune cache_input_tensors properly (#483)
Parent: fa0fca58
Showing 1 changed file with 31 additions and 37 deletions.
tilelang/autotuner/__init__.py  +31 -37
@@ -249,6 +249,19 @@ class AutoTuner:
         if self.jit_compile is None:
             self.jit_compile = _compile
 
+        # Factory functions for generating input tensors.
+        # This encapsulates the logic of using either a custom supply program (`supply_prog`)
+        # or the default profiler input generation (`profiler._get_inputs`).
+        def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
+
+            def func():
+                if supply_prog is not None:
+                    return supply_prog(profiler._get_params(with_output=with_output))
+                else:
+                    return profiler._get_inputs(with_output=with_output)
+
+            return func
+
         def target_fn(jit_context: JITContext):
             # Unpack the context
             kernel = jit_context.kernel
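Note that the factory is hoisted above `target_fn` and now receives `supply_prog` and `profiler` as explicit arguments instead of closing over `target_fn`'s locals, which is what lets the benchmarking stage (third hunk below) rebuild the same supply outside `target_fn`. Below is a minimal standalone sketch of the pattern; `DummyProfiler` and the integer "tensors" are illustrative stand-ins, not tilelang APIs.

    from typing import Callable, List, Optional

    class DummyProfiler:
        """Stand-in for the profiler; fabricates inputs like profiler._get_inputs would."""

        def _get_params(self, with_output: bool) -> List[str]:
            # Parameter descriptions a custom supply program would consume.
            return ["A", "B"] + (["C"] if with_output else [])

        def _get_inputs(self, with_output: bool) -> List[int]:
            # Default input-generation path.
            return [1, 2] + ([0] if with_output else [])

    def get_input_tensors_supply(supply_prog: Optional[Callable], profiler, with_output: bool):
        # Explicit arguments (no closure over enclosing locals) make the factory
        # reusable from any scope that has a supply_prog and a profiler in hand.
        def func():
            if supply_prog is not None:
                return supply_prog(profiler._get_params(with_output=with_output))
            return profiler._get_inputs(with_output=with_output)

        return func

    supply = get_input_tensors_supply(None, DummyProfiler(), with_output=False)
    print(supply())  # [1, 2] -- a fresh list on every call

Each call to `supply()` produces a fresh tensor list, so callers decide whether to cache the result or regenerate it per trial.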
@@ -264,57 +277,30 @@ class AutoTuner:
             profiler = kernel.get_profiler(tensor_supply_type=supply_type)
-
-            # Factory functions for generating input tensors.
-            # This encapsulates the logic of using either a custom supply program (`supply_prog`)
-            # or the default profiler input generation (`profiler._get_inputs`).
-            def get_input_tensors_supply(with_output: bool):
-
-                def func():
-                    if supply_prog is not None:
-                        return supply_prog(profiler._get_params(with_output=with_output))
-                    else:
-                        return profiler._get_inputs(with_output=with_output)
-
-                return func
-
-            jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
-            ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
-
-            if cache_input_tensors:
-                jit_input_tensors = jit_input_tensors_supply()
-                if self.jit_input_tensors is not None:
-                    if not check_tensor_list_compatibility(self.jit_input_tensors, jit_input_tensors):
-                        logger.warning(
-                            "Incompatible input tensor properties detected between cached tensors and "
-                            "tensors regenerated for the current configuration trial. "
-                            "This can happen if different tuning configurations require different input shapes/dtypes "
-                            "and input tensor caching is enabled.\n"
-                            "To ensure fresh, compatible inputs are generated for every trial "
-                            "you can disable caching by setting:\n"
-                            "    `cache_input_tensors=False`\n"
-                            "within your `.set_compile_args(...)` call.\n")
-                        self.jit_input_tensors = jit_input_tensors
-                self.jit_input_tensors = jit_input_tensors
-            else:
-                self.jit_input_tensors = jit_input_tensors_supply()
+            if cache_input_tensors and self.jit_input_tensors is not None:
+                jit_input_tensors = self.jit_input_tensors
+            else:
+                jit_input_tensors_supply = get_input_tensors_supply(supply_prog, profiler, with_output=False)
+                jit_input_tensors = jit_input_tensors_supply()
 
             if (not skip_check) and (ref_prog is not None):
                 if manual_check_prog is not None:
                     profiler.manual_assert_close(
                         ref_prog,
-                        input_tensors=self.jit_input_tensors,
+                        input_tensors=jit_input_tensors,
                         manual_check_prog=manual_check_prog)
                 else:
                     profiler.assert_allclose(
                         ref_prog,
-                        input_tensors=self.jit_input_tensors,
+                        input_tensors=jit_input_tensors,
                         rtol=rtol,
                         atol=atol,
                         max_mismatched_ratio=max_mismatched_ratio)
             latency = profiler.do_bench(
-                warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
+                warmup=warmup, rep=rep, input_tensors=jit_input_tensors)
             if self.ref_latency_cache is None and ref_prog is not None:
+                ref_input_tensors_supply = get_input_tensors_supply(supply_prog, profiler, with_output=False)
                 self.ref_input_tensors = ref_input_tensors_supply()
                 self.ref_latency_cache = profiler.do_bench(
                     ref_prog,
                     n_warmup=warmup,
                     n_repeat=rep,
                     input_tensors=self.ref_input_tensors)
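This hunk is the heart of the fix: with caching enabled and `self.jit_input_tensors` already populated, a trial now reuses the cached tensors, whereas the deleted block regenerated them on every trial and only used the cache to warn about incompatibilities. A self-contained sketch of the new decision rule, with illustrative names (only the shape of the logic mirrors the diff):

    from typing import Callable, List, Optional

    class TrialRunner:
        """Illustrative stand-in for the tuner's per-trial input handling."""

        def __init__(self) -> None:
            self.cached_inputs: Optional[List[int]] = None  # plays self.jit_input_tensors

        def inputs_for_trial(self, cache_input_tensors: bool,
                             supply: Callable[[], List[int]]) -> List[int]:
            # Reuse the cache only when caching is enabled AND already populated;
            # otherwise generate fresh inputs for this trial.
            if cache_input_tensors and self.cached_inputs is not None:
                return self.cached_inputs
            return supply()

    runner = TrialRunner()
    runner.cached_inputs = [1, 2, 3]
    assert runner.inputs_for_trial(True, lambda: [9, 9, 9]) == [1, 2, 3]   # cache hit
    assert runner.inputs_for_trial(False, lambda: [9, 9, 9]) == [9, 9, 9]  # caching off

As the removed warning text noted, passing `cache_input_tensors=False` through `.set_compile_args(...)` remains the way to force fresh inputs for every configuration trial.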
@@ -367,6 +353,14 @@ class AutoTuner:
                 continue
 
             ref_latency = None
+
+        if results_with_configs[0][0].cache_input_tensors:
+            supply_prog = results_with_configs[0][0].supply_prog
+            supply_type = results_with_configs[0][0].supply_type
+            profiler = results_with_configs[0][0].kernel.get_profiler(tensor_supply_type=supply_type)
+            jit_input_tensors_supply = get_input_tensors_supply(supply_prog, profiler, with_output=False)
+            self.jit_input_tensors = jit_input_tensors_supply()
+
         progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
         for i in progress_bar:
             jit_context, config = results_with_configs[i]
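Because per-trial generation no longer happens unconditionally, the bench stage now pre-populates `self.jit_input_tensors` once, from the first result's `supply_prog` and `supply_type`, so every configuration in the `tqdm` loop is measured against identical inputs and latencies stay comparable. A small illustrative sketch of this generate-once, bench-many shape (all names are stand-ins):

    import time
    from typing import Callable, List

    def bench(fn: Callable[[List[int]], int], inputs: List[int], rep: int = 100) -> float:
        # Average wall-clock time per call over `rep` repetitions.
        start = time.perf_counter()
        for _ in range(rep):
            fn(inputs)
        return (time.perf_counter() - start) / rep

    candidates = [sum, max]            # stand-ins for kernels compiled per config
    shared_inputs = list(range(1024))  # generated once, reused for every candidate
    latencies = [bench(fn, shared_inputs) for fn in candidates]
    best = min(range(len(candidates)), key=lambda i: latencies[i])
    print(f"best candidate index: {best}")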