"...git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "ca998f9ca511fff94be1951bb4197116518c43ea"
Unverified commit 16c955c5, authored by Sayak Paul and committed by GitHub

[tests] add test for torch.compile + group offloading (#11670)

* add a test for group offloading + compilation.

* tests
Parent: 0f91f2f6
@@ -1744,6 +1744,10 @@ class ModelPushToHubTester(unittest.TestCase):
         delete_repo(self.repo_id, token=TOKEN)


+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ class TorchCompileTesterMixin:
         gc.collect()
         backend_empty_cache(torch_device)

-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
@@ -1778,6 +1777,31 @@ class TorchCompileTesterMixin:
         _ = model(**inputs_dict)
         _ = model(**inputs_dict)

+    def test_compile_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)

 @slow
 @require_torch_2
...
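
For reference, the pattern the new test exercises looks roughly like this when applied to a standalone diffusers model. This is a minimal sketch, not taken from the commit: the tiny UNet2DConditionModel config, the input shapes, and the raised recompile limit mirror the spirit of the test but are illustrative, and a CUDA device is assumed since use_stream=True relies on CUDA streams.

import torch
from diffusers import UNet2DConditionModel

# The test raises dynamo's recompile limit; offloading hooks can trigger extra recompilations.
torch._dynamo.config.cache_size_limit = 10000

# A deliberately tiny UNet so the sketch runs quickly; any diffusers ModelMixin
# that supports group offloading could be used instead (config values are illustrative).
model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
    in_channels=4,
    out_channels=4,
).eval()

# Keep one block on the GPU at a time and prefetch the next block on a CUDA stream,
# matching the kwargs used in the new test.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
)

# Compile after enabling group offloading, mirroring the order used in the test.
model.compile()

# Two forward passes: the first triggers compilation, the second should reuse the compiled graph.
sample = torch.randn(1, 4, 32, 32, device="cuda")
timestep = torch.tensor([1], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 32, device="cuda")
with torch.no_grad():
    out = model(sample, timestep, encoder_hidden_states).sample
    out = model(sample, timestep, encoder_hidden_states).sample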