import inspect

import numpy as np
import pytest
import torch

from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.torch_utils import torch_device


class AutoencoderTesterMixin:
    """
    Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
    usually don't do slicing and tiling.
    """

    @staticmethod
    def _accepts_generator(model):
        model_sig = inspect.signature(model.forward)
        accepts_generator = "generator" in model_sig.parameters
        return accepts_generator

    @staticmethod
    def _accepts_norm_num_groups(model_class):
        model_sig = inspect.signature(model_class.__init__)
        accepts_norm_groups = "norm_num_groups" in model_sig.parameters
        return accepts_norm_groups

    def test_forward_with_norm_groups(self):
        """Smoke-test a forward pass with a non-default ``norm_num_groups`` value."""
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        # Shrink the model so the forward pass stays fast.
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)
            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        self.assertEqual(output.shape, inputs_dict["sample"].shape, "Input and output shapes do not match")

    def test_enable_disable_tiling(self):
        """Check that tiled inference stays close to non-tiled inference, and that
        ``disable_tiling`` restores outputs identical to the original run."""
        if not hasattr(self.model_class, "enable_tiling"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        if not hasattr(model, "use_tiling"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")

        inputs_dict.update({"return_dict": False})
        # Drop any pre-supplied generator; we reseed explicitly before each run so
        # all three forward passes see identical randomness.
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        torch.manual_seed(0)
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_tiling = model(**inputs_dict)[0]
        # Some models (e.g. Mochi-1) still wrap the tensor in a DecoderOutput
        # even with return_dict=False.
        if isinstance(output_without_tiling, DecoderOutput):
            output_without_tiling = output_without_tiling.sample

        torch.manual_seed(0)
        model.enable_tiling()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_with_tiling = model(**inputs_dict)[0]
        if isinstance(output_with_tiling, DecoderOutput):
            output_with_tiling = output_with_tiling.sample

        # Use the absolute difference: a large negative deviation would slip
        # through an unsigned `(a - b).max()` check.
        assert np.abs(
            output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
        ).max() < 0.5, "VAE tiling should not affect the inference results"

        torch.manual_seed(0)
        model.disable_tiling()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_tiling_2 = model(**inputs_dict)[0]
        if isinstance(output_without_tiling_2, DecoderOutput):
            output_without_tiling_2 = output_without_tiling_2.sample

        # BUGFIX: previously `.all()` was applied to both arrays *before*
        # `np.allclose`, reducing each to a scalar truth value, so the check was
        # almost always vacuously True. Compare the full arrays instead.
        assert np.allclose(
            output_without_tiling.detach().cpu().numpy(),
            output_without_tiling_2.detach().cpu().numpy(),
        ), "Without tiling outputs should match with the outputs when tiling is manually disabled."

    def test_enable_disable_slicing(self):
        """Check that sliced inference stays close to non-sliced inference, and that
        ``disable_slicing`` restores outputs identical to the original run."""
        if not hasattr(self.model_class, "enable_slicing"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        # BUGFIX: the skip message previously said "tiling" in the slicing test.
        if not hasattr(model, "use_slicing"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")

        inputs_dict.update({"return_dict": False})
        # Drop any pre-supplied generator; we reseed explicitly before each run so
        # all three forward passes see identical randomness.
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict)[0]
        # Some models (e.g. Mochi-1) still wrap the tensor in a DecoderOutput
        # even with return_dict=False.
        if isinstance(output_without_slicing, DecoderOutput):
            output_without_slicing = output_without_slicing.sample

        torch.manual_seed(0)
        model.enable_slicing()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_with_slicing = model(**inputs_dict)[0]
        if isinstance(output_with_slicing, DecoderOutput):
            output_with_slicing = output_with_slicing.sample

        # Use the absolute difference: a large negative deviation would slip
        # through an unsigned `(a - b).max()` check.
        assert np.abs(
            output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
        ).max() < 0.5, "VAE slicing should not affect the inference results"

        torch.manual_seed(0)
        model.disable_slicing()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_slicing_2 = model(**inputs_dict)[0]
        if isinstance(output_without_slicing_2, DecoderOutput):
            output_without_slicing_2 = output_without_slicing_2.sample

        # BUGFIX: previously `.all()` was applied to both arrays *before*
        # `np.allclose`, reducing each to a scalar truth value, so the check was
        # almost always vacuously True. Compare the full arrays instead.
        assert np.allclose(
            output_without_slicing.detach().cpu().numpy(),
            output_without_slicing_2.detach().cpu().numpy(),
        ), "Without slicing outputs should match with the outputs when slicing is manually disabled."