# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer

from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin


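# Opt in to fully deterministic torch ops so the pixel-slice assertions below are reproducible.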
enable_full_determinism()


class AltDiffusionPipelineFastTests(
    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
    pipeline_class = AltDiffusionPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
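        # All components are tiny and randomly initialized so the fast tests can run on CPU.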
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        # TODO: address the non-deterministic text encoder (fails for save-load tests)
        # torch.manual_seed(0)
        # text_encoder_config = RobertaSeriesConfig(
        #     hidden_size=32,
        #     project_dim=32,
        #     intermediate_size=37,
        #     layer_norm_eps=1e-05,
        #     num_attention_heads=4,
        #     num_hidden_layers=5,
        #     vocab_size=5002,
        # )
        # text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5002,
        )
        text_encoder = CLIPTextModel(text_encoder_config)

        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
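        # 77 matches the sequence length of the CLIP text encoder used by Stable-Diffusion-style pipelines.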
        tokenizer.model_max_length = 77

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
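        # Device type MPS is not supported for torch.Generator(), so fall back to the globally seeded CPU generator.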
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_attention_slicing_forward_pass(self):
        super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=3e-3)

    def test_alt_diffusion_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder

        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = "A photo of an astronaut"
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
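        # reference values for the bottom-right 3x3 patch of the last image channel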
        expected_slice = np.array(
            [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_pndm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder
        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


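# Integration tests load the full BAAI/AltDiffusion checkpoint, so they only run nightly on a CUDA GPU.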
@nightly
@require_torch_gpu
class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_alt_diffusion(self):
        # make sure that the default PNDM scheduler skips the PRK (Runge-Kutta warm-up) steps
        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_fast_ddim(self):
        scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")

        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)

        output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2