# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import unittest

import torch

from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    slow,
    torch_all_close,
    torch_device,
)

from .test_modeling_common import ModelTesterMixin, UNetTesterMixin


# Module-level logger obtained through the diffusers logging helper.
logger = logging.get_logger(__name__)

# Called once at import time so it is in effect for every test case defined below.
enable_full_determinism()


class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Test case for `UNet2DModel` built with attention down/up blocks on 3-channel 32x32 inputs.

    Mixes in `ModelTesterMixin` and `UNetTesterMixin`; the properties and
    `prepare_init_args_and_inputs_for_common` below supply the model class,
    constructor arguments and inputs used by this test case.
    """

    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        """Return the forward kwargs: a `floats_tensor` built from shape (4, 3, 32, 32) and timestep 10."""
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape as (channels, height, width)."""
        return (3, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape as (channels, height, width)."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)`: model constructor kwargs and the matching forward kwargs."""
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": 3,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_mid_block_attn_groups(self):
        """Check that `attn_norm_num_groups` reaches the mid-block attention group norm and the forward pass keeps the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Use a different group count for the attention norm (8) than for the
        # regular norm layers (16) so the assertion below can tell them apart.
        init_dict["norm_num_groups"] = 16
        init_dict["add_attention"] = True
        init_dict["attn_norm_num_groups"] = 8

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        self.assertIsNotNone(
            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
        )
        self.assertEqual(
            model.mid_block.attentions[0].group_norm.num_groups,
            init_dict["attn_norm_num_groups"],
            "Mid block Attention group norm does not have the expected number of groups.",
        )

        with torch.no_grad():
            output = model(**inputs_dict)

            # Unwrap dict-like model outputs to the first returned tensor.
            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")


class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Test case for `UNet2DModel` in a 4-channel, attention-free block configuration.

    Mixes in `ModelTesterMixin` and `UNetTesterMixin`, and adds tests that load
    the `fusing/unet-ldm-dummy-update` checkpoint from the Hub.
    """

    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        """Return the forward kwargs: a `floats_tensor` built from shape (4, 4, 32, 32) and timestep 10."""
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape as (channels, height, width)."""
        return (4, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape as (channels, height, width)."""
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)`: model constructor kwargs and the matching forward kwargs."""
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        """Load the dummy checkpoint, require no missing keys, and run one forward pass."""
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input).sample

        # unittest assertion instead of bare `assert`, which is stripped under `python -O`.
        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate(self):
        """Load the dummy checkpoint with default loading options and run one forward pass on GPU."""
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        self.assertIsNotNone(image, "Make sure output is not None")

    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate_wont_change_results(self):
        """Compare outputs of the default load path against `low_cpu_mem_usage=False` on identical inputs."""
        # by default model loading will use accelerate as `low_cpu_mem_usage=True`
        model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_accelerate.to(torch_device)
        model_accelerate.eval()

        # Seeded noise so both models see exactly the same input.
        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        # Inference only: no autograd graph is needed for the comparison.
        with torch.no_grad():
            arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # two models don't need to stay in the device at the same time
        del model_accelerate
        # Collect first so memory held by unreachable objects is actually free
        # by the time the CUDA caching allocator is asked to release it.
        gc.collect()
        torch.cuda.empty_cache()

        model_normal_load, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
        )
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        with torch.no_grad():
            arr_normal_load = model_normal_load(noise, time_step)["sample"]

        self.assertTrue(torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3))

    def test_output_pretrained(self):
        """Check a 3x3 corner slice of the pretrained model's output against stored reference values."""
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
        model.to(torch_device)

        noise = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        # Last channel of the first sample, bottom-right 3x3 patch.
        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    """Test case for `UNet2DModel` configured with skip blocks and a fourier time embedding.

    Mixes in `ModelTesterMixin` and `UNetTesterMixin`, and adds tests that load
    NCSN++ checkpoints from the Hub.
    """

    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self, sizes=(32, 32)):
        """Return the forward kwargs: a `floats_tensor` built from shape (4, 3, *sizes) and int32 timesteps.

        NOTE(review): as a property this is always read with the default
        `sizes`; the parameter is kept only to leave the signature unchanged.
        """
        batch_size = 4
        num_channels = 3

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        # One timestep entry per batch element, explicitly cast to int32.
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        """Per-sample input shape as (channels, height, width)."""
        return (3, 32, 32)

    @property
    def output_shape(self):
        """Per-sample output shape as (channels, height, width)."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)`: model constructor kwargs and the matching forward kwargs."""
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @slow
    def test_from_pretrained_hub(self):
        """Load `google/ncsnpp-celebahq-256`, require no missing keys, and run one forward pass at 256x256."""
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        # Swap the 32x32 dummy sample for a 256x256 one before the forward pass.
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        # unittest assertion instead of bare `assert`, which is stripped under `python -O`.
        self.assertIsNotNone(image, "Make sure output is not None")

    @slow
    def test_output_pretrained_ve_mid(self):
        """Check an output slice of `google/ncsnpp-celebahq-256` on an all-ones input against stored values."""
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        """Check an output slice of `fusing/ncsnpp-ffhq-ve-dummy-update` on an all-ones input against stored values."""
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_forward_with_norm_groups(self):
        """Override `test_forward_with_norm_groups` with a no-op."""
        # not required for this model
        pass