"vscode:/vscode.git/clone" did not exist on "493228a70835d8c4f4afd0c8507760d71dc43eae"
Unverified Commit 38ced7ee authored by tongyu's avatar tongyu Committed by GitHub
Browse files

[test_models_transformer_hunyuan_video] help us test torch.compile() for impactful models (#11431)



* Update test_models_transformer_hunyuan_video.py

* update

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 23c98025
...@@ -17,7 +17,14 @@ import unittest ...@@ -17,7 +17,14 @@ import unittest
import torch import torch
from diffusers import HunyuanVideoTransformer3DModel from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin from ..test_modeling_common import ModelTesterMixin
...@@ -89,6 +96,21 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): ...@@ -89,6 +96,21 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"HunyuanVideoTransformer3DModel"} expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
...@@ -157,6 +179,21 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T ...@@ -157,6 +179,21 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
expected_set = {"HunyuanVideoTransformer3DModel"} expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
...@@ -223,6 +260,21 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test ...@@ -223,6 +260,21 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
expected_set = {"HunyuanVideoTransformer3DModel"} expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
...@@ -290,3 +342,18 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u ...@@ -290,3 +342,18 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"} expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment