Unverified Commit af769881 authored by Sayak Paul, committed by GitHub

[tests] introduce `VAETesterMixin` to consolidate tests for slicing and tiling (#12374)

* up

* up

* up

* up

* up

* up

* up

* up

* up
parent 4715c5c7
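The diff below swaps `UNetTesterMixin` for the new `AutoencoderTesterMixin` across the autoencoder test suites, so the previously copy-pasted slicing/tiling tests live in one place. A minimal sketch of the resulting pattern (the `ToyAutoencoderTests` class and its stubbed helper are illustrative only, not part of the PR):

import unittest

from diffusers import AutoencoderKL

from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


class ToyAutoencoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    # Inheriting the mixin contributes test_enable_disable_tiling,
    # test_enable_disable_slicing, and test_forward_with_norm_groups; each
    # test skips itself when the model class lacks the corresponding API
    # (enable_tiling, enable_slicing, norm_num_groups).
    model_class = AutoencoderKL
    main_input_name = "sample"

    def prepare_init_args_and_inputs_for_common(self):
        # Illustrative stub; the real suites return model-specific init
        # kwargs and dummy inputs here.
        raise NotImplementedError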
@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
......
@@ -17,13 +17,14 @@ import unittest
from diffusers import AutoencoderKLCosmos
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
@@ -20,18 +20,15 @@ import torch
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
......
@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
......
@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
......
@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.
pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
@@ -18,13 +18,14 @@ import unittest
from diffusers import AutoencoderKLMagvit
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip(
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
)
def test_enable_disable_slicing(self):
pass
@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderKLMochi
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi
main_input_name = "sample"
base_precision = 1e-2
@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
"""
pass
@unittest.skip("Unsupported test.")
def test_model_parallelism(self):
"""
......
@@ -30,13 +30,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return
......
@@ -31,13 +31,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny
main_input_name = "sample"
base_precision = 1e-2
@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
def test_enable_disable_tiling(self):
pass
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict)[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict)[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test not supported.")
def test_outputs_equivalence(self):
pass
......
@@ -15,18 +15,17 @@
import unittest
import torch
from diffusers import AutoencoderKLWan
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling(96, 96, 64, 64)
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.05,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
......
@@ -31,12 +31,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
......
@@ -19,19 +19,15 @@ import torch
from diffusers import VQModel
from ...testing_utils import (
backend_manual_seed,
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel
main_input_name = "sample"
......
import inspect
import numpy as np
import pytest
import torch
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.torch_utils import torch_device
class AutoencoderTesterMixin:
"""
Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
usually don't do slicing and tiling.
"""
@staticmethod
def _accepts_generator(model):
model_sig = inspect.signature(model.forward)
accepts_generator = "generator" in model_sig.parameters
return accepts_generator
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
torch.manual_seed(0)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling = model(**inputs_dict)[0]
# Mochi-1 returns a DecoderOutput here even with return_dict=False
if isinstance(output_without_tiling, DecoderOutput):
output_without_tiling = output_without_tiling.sample
torch.manual_seed(0)
model.enable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_tiling = model(**inputs_dict)[0]
if isinstance(output_with_tiling, DecoderOutput):
output_with_tiling = output_with_tiling.sample
assert (
output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
).max() < 0.5, "VAE tiling should not affect the inference results"
torch.manual_seed(0)
model.disable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling_2 = model(**inputs_dict)[0]
if isinstance(output_without_tiling_2, DecoderOutput):
output_without_tiling_2 = output_without_tiling_2.sample
assert np.allclose(
output_without_tiling.detach().cpu().numpy(),
output_without_tiling_2.detach().cpu().numpy(),
), "Outputs without tiling should match the outputs once tiling is disabled again."
def test_enable_disable_slicing(self):
if not hasattr(self.model_class, "enable_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
# Mochi-1 returns a DecoderOutput here even with return_dict=False
if isinstance(output_without_slicing, DecoderOutput):
output_without_slicing = output_without_slicing.sample
torch.manual_seed(0)
model.enable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_slicing = model(**inputs_dict)[0]
if isinstance(output_with_slicing, DecoderOutput):
output_with_slicing = output_with_slicing.sample
assert (
output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
).max() < 0.5, "VAE slicing should not affect the inference results"
torch.manual_seed(0)
model.disable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_slicing_2 = model(**inputs_dict)[0]
if isinstance(output_without_slicing_2, DecoderOutput):
output_without_slicing_2 = output_without_slicing_2.sample
assert np.allclose(
output_without_slicing.detach().cpu().numpy(),
output_without_slicing_2.detach().cpu().numpy(),
), "Outputs without slicing should match the outputs once slicing is disabled again."
@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):
class UNetTesterMixin:
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
......
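Both `AutoencoderTesterMixin` and the updated `UNetTesterMixin` gate `test_forward_with_norm_groups` on constructor introspection rather than per-class `unittest.skip` overrides, which is what lets the hand-written skips above be deleted. A standalone sketch of the same `inspect.signature` technique (the `ToyVAE` class is hypothetical):

import inspect


class ToyVAE:
    def __init__(self, norm_num_groups: int = 32):
        self.norm_num_groups = norm_num_groups

    def forward(self, sample, generator=None):
        return sample


# Mirrors _accepts_norm_num_groups: run the norm-groups test only when
# __init__ actually takes a norm_num_groups argument.
assert "norm_num_groups" in inspect.signature(ToyVAE.__init__).parameters

# Mirrors _accepts_generator: pass a generator only to models whose
# forward() accepts one.
assert "generator" in inspect.signature(ToyVAE(norm_num_groups=16).forward).parameters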