"vscode:/vscode.git/clone" did not exist on "f56aa20014efeb383af5380d3de35475d1f08c36"
test_alt_diffusion.py 8.87 KB
Newer Older
Suraj Patil's avatar
Suraj Patil committed
1
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, XLMRobertaTokenizer

from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


# disable TF32 matmuls so CUDA results stay reproducible for the expected-slice comparisons below
torch.backends.cuda.matmul.allow_tf32 = False


class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AltDiffusionPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
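        # build tiny, fixed-seed versions of every pipeline component so the fast tests stay cheap and deterministic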
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )

        # TODO: address the non-deterministic text encoder (fails for save-load tests)
        # torch.manual_seed(0)
        # text_encoder_config = RobertaSeriesConfig(
        #     hidden_size=32,
        #     project_dim=32,
        #     intermediate_size=37,
        #     layer_norm_eps=1e-05,
        #     num_attention_heads=4,
        #     num_hidden_layers=5,
        #     vocab_size=5002,
        # )
        # text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            projection_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=5002,
        )
        text_encoder = CLIPTextModel(text_encoder_config)

        tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
        tokenizer.model_max_length = 77

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
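        # device-local torch.Generator objects are not reliably supported on MPS, so seed the global RNG instead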
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_alt_diffusion_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder

        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = "A photo of an astronaut"
        output = alt_pipe(**inputs)
        image = output.images
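        # inspect a 3x3 patch from the bottom-right corner of the last channel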
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_pndm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
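        # PNDM with skip_prk_steps=True uses only the PLMS steps, skipping the Runge-Kutta warm-up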
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        torch.manual_seed(0)
        text_encoder_config = RobertaSeriesConfig(
            hidden_size=32,
            project_dim=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            vocab_size=5002,
        )
        # TODO: remove after fixing the non-deterministic text encoder
        text_encoder = RobertaSeriesModelWithTransformation(text_encoder_config)
        components["text_encoder"] = text_encoder
        alt_pipe = AltDiffusionPipeline(**components)
        alt_pipe = alt_pipe.to(device)
        alt_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = alt_pipe(**inputs)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array(
            [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch_gpu
class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_alt_diffusion(self):
        # make sure the pipeline's default PNDM scheduler skips the PRK steps
        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")

        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_alt_diffusion_fast_ddim(self):
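        # load the DDIM scheduler from the checkpoint's "scheduler" subfolder, replacing the pipeline's default scheduler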
        scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")

        alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None)
        alt_pipe = alt_pipe.to(torch_device)
        alt_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)

        output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2