Unverified Commit ffc8c0c1 authored by Sayak Paul, committed by GitHub

[tests] feat: add AoT compilation tests (#12203)

* feat: add a test for aot.

* up
parent 4acbfbf1
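The diff below wires PyTorch's ahead-of-time (AoT) Inductor path into the model test mixin: the model is exported with torch.export.export, compiled and packaged into a .pt2 archive via torch._inductor.aoti_compile_and_package, and the archive is reloaded with load_package. As a minimal standalone sketch of that same workflow (assuming a recent PyTorch where these APIs exist; TinyModel and its example inputs are hypothetical stand-ins for a diffusers model class):

# Minimal sketch of the AoT flow exercised by the new test.
# TinyModel and its example inputs are hypothetical stand-ins for a diffusers model.
import os
import tempfile

import torch
from torch._inductor.package import load_package


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = TinyModel().eval()
example_kwargs = {"x": torch.randn(2, 8)}

# 1. Export the model ahead of time.
exported = torch.export.export(model, args=(), kwargs=example_kwargs)

with tempfile.TemporaryDirectory() as tmpdir:
    package_path = os.path.join(tmpdir, "TinyModel.pt2")
    # 2. AoT-compile with Inductor and package everything into a .pt2 archive.
    torch._inductor.aoti_compile_and_package(exported, package_path=package_path)
    # 3. Reload the compiled artifact; it is a callable with the original signature.
    compiled = load_package(package_path, run_single_threaded=True)
    with torch.no_grad():
        out = compiled(**example_kwargs)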
@@ -2059,6 +2059,7 @@ class TorchCompileTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True)
 
         with (
@@ -2076,6 +2077,7 @@ class TorchCompileTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model.compile_repeated_blocks(fullgraph=True)
 
         recompile_limit = 1
@@ -2098,7 +2100,6 @@ class TorchCompileTesterMixin:
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.eval()
-
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
@@ -2111,11 +2112,11 @@ class TorchCompileTesterMixin:
         }
         model.enable_group_offload(**group_offload_kwargs)
         model.compile()
-
         with torch.no_grad():
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
+    @require_torch_version_greater("2.7.1")
     def test_compile_on_different_shapes(self):
         if self.different_shapes_for_compilation is None:
             pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
@@ -2123,6 +2124,7 @@ class TorchCompileTesterMixin:
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True, dynamic=True)
 
         for height, width in self.different_shapes_for_compilation:
@@ -2130,6 +2132,26 @@ class TorchCompileTesterMixin:
             inputs_dict = self.prepare_dummy_input(height=height, width=width)
             _ = model(**inputs_dict)
+
+    def test_compile_works_with_aot(self):
+        from torch._inductor.package import load_package
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
+            _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
+            assert os.path.exists(package_path)
+            loaded_binary = load_package(package_path, run_single_threaded=True)
+
+        model.forward = loaded_binary
+
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
 
 
 @slow
 @require_torch_2
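One packaging detail the new test relies on: the object returned by load_package is a plain callable rather than an nn.Module, so the test assigns it to model.forward, letting the usual model(**inputs_dict) call sites route through the compiled artifact. The .pt2 archive is also self-contained, so it can be reloaded in a fresh process without re-running export or compilation. A minimal sketch, assuming "TinyModel.pt2" is a hypothetical path persisted by a workflow like the one above rather than written into a temporary directory:

# Reload a previously AoT-compiled .pt2 package without re-running export.
# "TinyModel.pt2" is a hypothetical path saved by an earlier run.
import torch
from torch._inductor.package import load_package

compiled = load_package("TinyModel.pt2", run_single_threaded=True)
with torch.no_grad():
    out = compiled(x=torch.randn(2, 8))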