import gc
import tempfile
from typing import Callable, Union

import pytest
import torch

import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.utils import logging

from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device


class ModularPipelineTesterMixin:
    """
    It provides a set of common tests for each modular pipeline,
    including:
    - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
    - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
    - test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
    - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
    - test_to_device: check if the pipeline's __call__ method can handle different devices
    """

    # Canonical parameters that are passed to `__call__` regardless of the
    # type of pipeline. They are always optional and have common-sense
    # default values.
    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
    # This is modular-specific: `generator` needs to be an intermediate input because it is mutable.
    intermediate_params = frozenset(["generator"])

    def get_generator(self, seed=0):
        generator = torch.Generator("cpu").manual_seed(seed)
        return generator

    @property
    def pipeline_class(self) -> Union[Callable, ModularPipeline]:
        raise NotImplementedError(
            "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
            "See existing pipeline tests for reference."
        )

    @property
    def repo(self) -> str:
        raise NotImplementedError(
            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
        )

    @property
    def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
        raise NotImplementedError(
            "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
            "See existing pipeline tests for reference."
        )

    def get_dummy_inputs(self, seed=0):
        raise NotImplementedError(
            "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
            "See existing pipeline tests for reference."
        )

    @property
    def params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `params` in the child test class. "
            "`params` are checked for if all values are present in `__call__`'s signature."
            " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
            " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to  "
            "image pipelines, including prompts and prompt embedding overrides."
            "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
            "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
            "with non-configurable height and width arguments should set the attribute as "
            "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
            "See existing pipeline tests for reference."
        )

    @property
    def batch_params(self) -> frozenset:
        raise NotImplementedError(
            "You need to set the attribute `batch_params` in the child test class. "
            "`batch_params` are the parameters required to be batched when passed to the pipeline's "
            "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
            "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
            "set of batch arguments has minor changes from one of the common sets of batch arguments, "
            "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
            "image pipeline `negative_prompt` is not batched should set the attribute as "
            "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
            "See existing pipeline tests for reference."
        )

    def setup_method(self):
        # clean up the VRAM before each test
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        torch.compiler.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
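        # Instantiate the modular pipeline from the blocks class and load its pretrained
        # components from `self.repo`.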
        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
        pipeline.load_components(torch_dtype=torch_dtype)
        pipeline.set_progress_bar_config(disable=None)
        return pipeline

    def test_pipeline_call_signature(self):
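        # Check that every required input parameter and every expected optional parameter
        # declared by the test class is exposed by the pipeline blocks.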
        pipe = self.get_pipeline()
        input_parameters = pipe.blocks.input_names
        optional_parameters = pipe.default_call_parameters

        def _check_for_parameters(parameters, expected_parameters, param_type):
            remaining_parameters = {param for param in parameters if param not in expected_parameters}
            assert len(remaining_parameters) == 0, (
                f"Required {param_type} parameters not present: {remaining_parameters}"
            )

        _check_for_parameters(self.params, input_parameters, "input")
        _check_for_parameters(self.optional_params, optional_parameters, "optional")

    def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
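        # Run the pipeline with inputs replicated to each batch size in `batch_sizes`
        # and check that the number of returned images matches the batch size.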
        pipe = self.get_pipeline().to(torch_device)

        inputs = self.get_dummy_inputs()
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # prepare batched inputs
        batched_inputs = []
        for batch_size in batch_sizes:
            batched_input = {}
            batched_input.update(inputs)

            for name in self.batch_params:
                if name not in inputs:
                    continue

                value = inputs[name]
                batched_input[name] = batch_size * [value]

            if batch_generator and "generator" in inputs:
                batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]

            if "batch_size" in inputs:
                batched_input["batch_size"] = batch_size

            batched_inputs.append(batched_input)

        logger.setLevel(level=diffusers.logging.WARNING)
        for batch_size, batched_input in zip(batch_sizes, batched_inputs):
            output = pipe(**batched_input, output="images")
            assert len(output) == batch_size, "Number of outputs differs from the expected batch size"

    def test_inference_batch_single_identical(
        self,
        batch_size=2,
        expected_max_diff=1e-4,
    ):
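        # Check that the first sample of a batched run matches the result of a
        # single-input run within `expected_max_diff`.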
        pipe = self.get_pipeline().to(torch_device)

        inputs = self.get_dummy_inputs()

        # Reset generator in case it has been used in self.get_dummy_inputs
        inputs["generator"] = self.get_generator(0)

        logger = logging.get_logger(pipe.__module__)
        logger.setLevel(level=diffusers.logging.FATAL)

        # batchify inputs
        batched_inputs = {}
        batched_inputs.update(inputs)

        for name in self.batch_params:
            if name not in inputs:
                continue

            value = inputs[name]
            batched_inputs[name] = batch_size * [value]

        if "generator" in inputs:
            batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]

        if "batch_size" in inputs:
            batched_inputs["batch_size"] = batch_size

        output = pipe(**inputs, output="images")
        output_batch = pipe(**batched_inputs, output="images")

        assert output_batch.shape[0] == batch_size

        max_diff = torch.abs(output_batch[0] - output[0]).max()
        assert max_diff < expected_max_diff, "Batch inference results different from single inference results"

    @require_accelerator
    def test_float16_inference(self, expected_max_diff=5e-2):
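        # Run the same inputs through fp32 and fp16 copies of the pipeline and check that
        # the outputs are close in terms of cosine-similarity distance.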
        pipe = self.get_pipeline()
        pipe.to(torch_device, torch.float32)

        pipe_fp16 = self.get_pipeline()
        pipe_fp16.to(torch_device, torch.float16)

        inputs = self.get_dummy_inputs()
        # Reset generator in case it is used inside dummy inputs
        if "generator" in inputs:
            inputs["generator"] = self.get_generator(0)
        output = pipe(**inputs, output="images")

        fp16_inputs = self.get_dummy_inputs()
        # Reset generator in case it is used inside dummy inputs
        if "generator" in fp16_inputs:
            fp16_inputs["generator"] = self.get_generator(0)
        output_fp16 = pipe_fp16(**fp16_inputs, output="images")

        output = output.cpu()
        output_fp16 = output_fp16.cpu()

        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
        assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"

    @require_accelerator
    def test_to_device(self):
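        # Check that all components report the expected device after moving the pipeline
        # between CPU and the accelerator.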
        pipe = self.get_pipeline().to("cpu")

        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"

        pipe.to(torch_device)
        model_devices = [
            component.device.type for component in pipe.components.values() if hasattr(component, "device")
        ]
        assert all(device == torch_device for device in model_devices), (
            "Not all pipeline components are on the accelerator device"
        )

    def test_inference_is_not_nan_cpu(self):
        pipe = self.get_pipeline().to("cpu")

        output = pipe(**self.get_dummy_inputs(), output="images")
        assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"

    @require_accelerator
    def test_inference_is_not_nan(self):
        pipe = self.get_pipeline().to(torch_device)

        output = pipe(**self.get_dummy_inputs(), output="images")
        assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"

    def test_num_images_per_prompt(self):
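        # Check that the number of generated images equals batch_size * num_images_per_prompt.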
        pipe = self.get_pipeline().to(torch_device)

        if "num_images_per_prompt" not in pipe.blocks.input_names:
            pytest.skip("Skipping test as `num_images_per_prompt` is not present in input names.")

        batch_sizes = [1, 2]
        num_images_per_prompts = [1, 2]

        for batch_size in batch_sizes:
            for num_images_per_prompt in num_images_per_prompts:
                inputs = self.get_dummy_inputs()

                for key in inputs.keys():
                    if key in self.batch_params:
                        inputs[key] = batch_size * [inputs[key]]

                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")

                assert images.shape[0] == batch_size * num_images_per_prompt

    @require_accelerator
    def test_components_auto_cpu_offload_inference_consistent(self):
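        # Check that inference with ComponentsManager auto CPU offloading matches
        # inference without offloading.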
        base_pipe = self.get_pipeline().to(torch_device)

        cm = ComponentsManager()
        cm.enable_auto_cpu_offload(device=torch_device)
        offload_pipe = self.get_pipeline(components_manager=cm)

        image_slices = []
        for pipe in [base_pipe, offload_pipe]:
            inputs = self.get_dummy_inputs()
            image = pipe(**inputs, output="images")

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

    def test_save_from_pretrained(self):
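        # Check that a pipeline reloaded via save_pretrained/from_pretrained produces
        # the same images as the original pipeline.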
        pipes = []
        base_pipe = self.get_pipeline().to(torch_device)
        pipes.append(base_pipe)

        with tempfile.TemporaryDirectory() as tmpdirname:
            base_pipe.save_pretrained(tmpdirname)
            pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
            pipe.load_components(torch_dtype=torch.float32)
            pipe.to(torch_device)

        pipes.append(pipe)

        image_slices = []
        for pipe in pipes:
            inputs = self.get_dummy_inputs()
            image = pipe(**inputs, output="images")

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3


class ModularGuiderTesterMixin:
    def test_guider_cfg(self, expected_max_diff=1e-2):
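        # Check that running with classifier-free guidance enabled (guidance_scale > 1)
        # produces a different output than running with it disabled (guidance_scale = 1.0).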
        pipe = self.get_pipeline().to(torch_device)

        # forward pass with CFG not applied
        guider = ClassifierFreeGuidance(guidance_scale=1.0)
        pipe.update_components(guider=guider)

        inputs = self.get_dummy_inputs()
        out_no_cfg = pipe(**inputs, output="images")

        # forward pass with CFG applied
        guider = ClassifierFreeGuidance(guidance_scale=7.5)
        pipe.update_components(guider=guider)
        inputs = self.get_dummy_inputs()
        out_cfg = pipe(**inputs, output="images")

        assert out_cfg.shape == out_no_cfg.shape
        max_diff = torch.abs(out_cfg - out_no_cfg).max()
        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"