Unverified Commit af769881 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] introduce `VAETesterMixin` to consolidate tests for slicing and tiling (#12374)

* up

* up

* up

* up

* up

* up

* up

* up

* up
parent 4715c5c7
...@@ -35,13 +35,14 @@ from ...testing_utils import ( ...@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close, torch_all_close,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL model_class = AsymmetricAutoencoderKL
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
......
...@@ -17,13 +17,14 @@ import unittest ...@@ -17,13 +17,14 @@ import unittest
from diffusers import AutoencoderKLCosmos from diffusers import AutoencoderKLCosmos
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos model_class = AutoencoderKLCosmos
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC ...@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Not sure why this test fails. Investigate later.") @unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self): def test_effective_gradient_checkpointing(self):
pass pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
...@@ -22,13 +22,14 @@ from ...testing_utils import ( ...@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor, floats_tensor,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC model_class = AutoencoderDC
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
init_dict = self.get_autoencoder_dc_config() init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
...@@ -20,18 +20,15 @@ import torch ...@@ -20,18 +20,15 @@ import torch
from diffusers import AutoencoderKLHunyuanVideo from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from ...testing_utils import ( from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
enable_full_determinism, from ..test_modeling_common import ModelTesterMixin
floats_tensor, from .testing_utils import AutoencoderTesterMixin
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest ...@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = { expected_set = {
"HunyuanVideoDecoder3D", "HunyuanVideoDecoder3D",
......
...@@ -35,13 +35,14 @@ from ...testing_utils import ( ...@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close, torch_all_close,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL model_class = AutoencoderKL
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"} expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
......
...@@ -24,13 +24,14 @@ from ...testing_utils import ( ...@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor, floats_tensor,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX model_class = AutoencoderKLCogVideoX
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te ...@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = { expected_set = {
"CogVideoXDownBlock3D", "CogVideoXDownBlock3D",
......
...@@ -22,13 +22,14 @@ from ...testing_utils import ( ...@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor, floats_tensor,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt ...@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"} expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
...@@ -24,13 +24,14 @@ from ...testing_utils import ( ...@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor, floats_tensor,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo model_class = AutoencoderKLLTXVideo
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest. ...@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.
pass pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo model_class = AutoencoderKLLTXVideo
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest. ...@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
pass pass
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
...@@ -18,13 +18,14 @@ import unittest ...@@ -18,13 +18,14 @@ import unittest
from diffusers import AutoencoderKLMagvit from diffusers import AutoencoderKLMagvit
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit model_class = AutoencoderKLMagvit
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC ...@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Unsupported test.") @unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
pass pass
@unittest.skip(
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
)
def test_enable_disable_slicing(self):
pass
...@@ -17,18 +17,15 @@ import unittest ...@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderKLMochi from diffusers import AutoencoderKLMochi
from ...testing_utils import ( from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
enable_full_determinism, from ..test_modeling_common import ModelTesterMixin
floats_tensor, from .testing_utils import AutoencoderTesterMixin
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi model_class = AutoencoderKLMochi
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa ...@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
} }
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
"""
pass
@unittest.skip("Unsupported test.") @unittest.skip("Unsupported test.")
def test_model_parallelism(self): def test_model_parallelism(self):
""" """
......
...@@ -30,13 +30,14 @@ from ...testing_utils import ( ...@@ -30,13 +30,14 @@ from ...testing_utils import (
torch_all_close, torch_all_close,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck model_class = AutoencoderOobleck
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa ...@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
"Without slicing outputs should match with the outputs when slicing is manually disabled.", "Without slicing outputs should match with the outputs when slicing is manually disabled.",
) )
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("No attention module used in this model") @unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self): def test_set_attn_processor_for_determinism(self):
return return
......
...@@ -31,13 +31,14 @@ from ...testing_utils import ( ...@@ -31,13 +31,14 @@ from ...testing_utils import (
torch_all_close, torch_all_close,
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny model_class = AutoencoderTiny
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase) ...@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
def test_enable_disable_tiling(self): def test_enable_disable_tiling(self):
pass pass
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict)[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict)[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test not supported.") @unittest.skip("Test not supported.")
def test_outputs_equivalence(self): def test_outputs_equivalence(self):
pass pass
......
...@@ -15,18 +15,17 @@ ...@@ -15,18 +15,17 @@
import unittest import unittest
import torch
from diffusers import AutoencoderKLWan from diffusers import AutoencoderKLWan
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan model_class = AutoencoderKLWan
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase ...@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
inputs_dict = self.dummy_input_tiling inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling(96, 96, 64, 64)
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.05,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Gradient checkpointing has not been implemented yet") @unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
pass pass
......
...@@ -31,12 +31,13 @@ from ...testing_utils import ( ...@@ -31,12 +31,13 @@ from ...testing_utils import (
torch_device, torch_device,
) )
from ..test_modeling_common import ModelTesterMixin from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism() enable_full_determinism()
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE model_class = ConsistencyDecoderVAE
main_input_name = "sample" main_input_name = "sample"
base_precision = 1e-2 base_precision = 1e-2
...@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): ...@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict() return self.init_dict, self.inputs_dict()
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@slow @slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
......
...@@ -19,19 +19,15 @@ import torch ...@@ -19,19 +19,15 @@ import torch
from diffusers import VQModel from diffusers import VQModel
from ...testing_utils import ( from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
backend_manual_seed, from ..test_modeling_common import ModelTesterMixin
enable_full_determinism, from .testing_utils import AutoencoderTesterMixin
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism() enable_full_determinism()
class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel model_class = VQModel
main_input_name = "sample" main_input_name = "sample"
......
import inspect
import numpy as np
import pytest
import torch
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.torch_utils import torch_device
class AutoencoderTesterMixin:
    """
    Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
    usually don't do slicing and tiling.

    The host test class must provide a `model_class` attribute and a
    `prepare_init_args_and_inputs_for_common()` method returning
    `(init_dict, inputs_dict)`.
    """

    @staticmethod
    def _accepts_generator(model):
        """Return True if `model.forward` declares a `generator` parameter."""
        return "generator" in inspect.signature(model.forward).parameters

    @staticmethod
    def _accepts_norm_num_groups(model_class):
        """Return True if `model_class.__init__` declares a `norm_num_groups` parameter."""
        return "norm_num_groups" in inspect.signature(model_class.__init__).parameters

    @staticmethod
    def _max_abs_diff(first, second):
        """Return the largest element-wise absolute difference between two tensors."""
        delta = first.detach().cpu().numpy() - second.detach().cpu().numpy()
        # Absolute value matters: a signed max would ignore negative deviations.
        return np.abs(delta).max()

    @staticmethod
    def _outputs_match(first, second):
        """Return True if two tensors are element-wise equal within `np.allclose` tolerances."""
        # Compare the arrays themselves. Reducing each side with `.all()` first
        # would only compare two booleans.
        return bool(np.allclose(first.detach().cpu().numpy(), second.detach().cpu().numpy()))

    @staticmethod
    def _seeded_forward(model, inputs_dict, accepts_generator):
        """
        Run `model(**inputs_dict)` with the global RNG reseeded to 0 and return the first output.

        When `accepts_generator` is true, a generator seeded with 0 is stored under
        `inputs_dict["generator"]` before the call (the dict is mutated in place).
        """
        torch.manual_seed(0)
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output = model(**inputs_dict)[0]
        # The first element can be a DecoderOutput rather than a tensor
        # (the original code attributes this to Mochi-1); unwrap it.
        if isinstance(output, DecoderOutput):
            output = output.sample
        return output

    def _check_enable_disable(self, feature):
        """
        Shared body of the tiling/slicing tests.

        `feature` is "tiling" or "slicing". The model is run with the feature in its
        default state, after `enable_<feature>()`, and after `disable_<feature>()`,
        each time with identical seeding. The test is skipped when the model class
        has no `enable_<feature>` attribute.
        """
        if not hasattr(self.model_class, f"enable_{feature}"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support {feature}.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict["return_dict"] = False
        inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        baseline = self._seeded_forward(model, inputs_dict, accepts_generator)

        getattr(model, f"enable_{feature}")()
        with_feature = self._seeded_forward(model, inputs_dict, accepts_generator)
        assert self._max_abs_diff(baseline, with_feature) < 0.5, (
            f"VAE {feature} should not affect the inference results"
        )

        getattr(model, f"disable_{feature}")()
        after_disable = self._seeded_forward(model, inputs_dict, accepts_generator)
        assert self._outputs_match(baseline, after_disable), (
            f"Without {feature} outputs should match with the outputs when {feature} is manually disabled."
        )

    def test_forward_with_norm_groups(self):
        """Forward pass with `norm_num_groups=16` must return an output shaped like the input sample."""
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)
            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_enable_disable_tiling(self):
        """Enabling tiling must keep outputs close; disabling it must restore the baseline output."""
        self._check_enable_disable("tiling")

    def test_enable_disable_slicing(self):
        """Enabling slicing must keep outputs close; disabling it must restore the baseline output."""
        self._check_enable_disable("slicing")
...@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase): ...@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):
class UNetTesterMixin: class UNetTesterMixin:
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16 init_dict["norm_num_groups"] = 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment