Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ffc8c0c1
Unverified
Commit
ffc8c0c1
authored
Sep 03, 2025
by
Sayak Paul
Committed by
GitHub
Sep 03, 2025
Browse files
[tests] feat: add AoT compilation tests (#12203)
* feat: add a test for aot. * up
parent
4acbfbf1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
2 deletions
+24
-2
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+24
-2
No files found.
tests/models/test_modeling_common.py
View file @
ffc8c0c1
...
...
@@ -2059,6 +2059,7 @@ class TorchCompileTesterMixin:
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
).
to
(
torch_device
)
model
.
eval
()
model
=
torch
.
compile
(
model
,
fullgraph
=
True
)
with
(
...
...
@@ -2076,6 +2077,7 @@ class TorchCompileTesterMixin:
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
).
to
(
torch_device
)
model
.
eval
()
model
.
compile_repeated_blocks
(
fullgraph
=
True
)
recompile_limit
=
1
...
...
@@ -2098,7 +2100,6 @@ class TorchCompileTesterMixin:
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
eval
()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs
=
{
...
...
@@ -2111,11 +2112,11 @@ class TorchCompileTesterMixin:
}
model
.
enable_group_offload
(
**
group_offload_kwargs
)
model
.
compile
()
with
torch
.
no_grad
():
_
=
model
(
**
inputs_dict
)
_
=
model
(
**
inputs_dict
)
@
require_torch_version_greater
(
"2.7.1"
)
def
test_compile_on_different_shapes
(
self
):
if
self
.
different_shapes_for_compilation
is
None
:
pytest
.
skip
(
f
"Skipping as `different_shapes_for_compilation` is not set for
{
self
.
__class__
.
__name__
}
."
)
...
...
@@ -2123,6 +2124,7 @@ class TorchCompileTesterMixin:
init_dict
,
_
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
).
to
(
torch_device
)
model
.
eval
()
model
=
torch
.
compile
(
model
,
fullgraph
=
True
,
dynamic
=
True
)
for
height
,
width
in
self
.
different_shapes_for_compilation
:
...
...
@@ -2130,6 +2132,26 @@ class TorchCompileTesterMixin:
inputs_dict
=
self
.
prepare_dummy_input
(
height
=
height
,
width
=
width
)
_
=
model
(
**
inputs_dict
)
def
test_compile_works_with_aot
(
self
):
from
torch._inductor.package
import
load_package
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
).
to
(
torch_device
)
exported_model
=
torch
.
export
.
export
(
model
,
args
=
(),
kwargs
=
inputs_dict
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdir
:
package_path
=
os
.
path
.
join
(
tmpdir
,
f
"
{
self
.
model_class
.
__name__
}
.pt2"
)
_
=
torch
.
_inductor
.
aoti_compile_and_package
(
exported_model
,
package_path
=
package_path
)
assert
os
.
path
.
exists
(
package_path
)
loaded_binary
=
load_package
(
package_path
,
run_single_threaded
=
True
)
model
.
forward
=
loaded_binary
with
torch
.
no_grad
():
_
=
model
(
**
inputs_dict
)
_
=
model
(
**
inputs_dict
)
@
slow
@
require_torch_2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment