# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import unittest

import torch

from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    require_torch_accelerator,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


logger = logging.get_logger(__name__)

enable_full_determinism()


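# Three configurations of UNet2DModel are exercised below: a tiny DDPM-style UNet,
# an LDM-style UNet, and an NCSN++ (score-based, variance-exploding) UNet.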
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

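    # A small random batch keeps the forward passes in these tests cheap while
    # still exercising every block type in the tiny config below.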
    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 2,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": 3,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

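    # `attn_norm_num_groups` adds a GroupNorm to the mid-block attention; verify the
    # layer exists, uses the requested number of groups, and that the forward pass
    # preserves the input shape.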
    def test_mid_block_attn_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["add_attention"] = True
        init_dict["attn_norm_num_groups"] = 4

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        self.assertIsNotNone(
            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
        )
        self.assertEqual(
            model.mid_block.attentions[0].group_norm.num_groups,
            init_dict["attn_norm_num_groups"],
            "Mid block Attention group norm does not have the expected number of groups.",
        )

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

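    # Build two models from the same config, one with the default mid block and one
    # with `mid_block_type=None`; their outputs should differ, proving the mid block
    # actually contributes to the result.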
    def test_mid_block_none(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
        mid_none_init_dict["mid_block_type"] = None

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        mid_none_model = self.model_class(**mid_none_init_dict)
        mid_none_model.to(torch_device)
        mid_none_model.eval()

        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        with torch.no_grad():
            mid_none_output = mid_none_model(**mid_none_inputs_dict)

            if isinstance(mid_none_output, dict):
                mid_none_output = mid_none_output.to_tuple()[0]

        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")

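    # `expected_set` lists every block class that should end up with gradient
    # checkpointing enabled for this configuration.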
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "AttnUpBlock2D",
            "AttnDownBlock2D",
            "UNetMidBlock2D",
            "UpBlock2D",
            "DownBlock2D",
        }

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 8
        block_out_channels = (16, 32)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )


class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

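    # `output_loading_info=True` additionally returns the missing/unexpected keys so
    # the test can assert the checkpoint exactly matches the architecture.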
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

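    # Smoke test: the default (accelerate-backed) loading path should produce a
    # model that runs on the accelerator device.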
    @require_torch_accelerator
    def test_from_pretrained_accelerate(self):
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

    @require_torch_accelerator
    def test_from_pretrained_accelerate_wont_change_results(self):
        # by default, model loading uses accelerate because `low_cpu_mem_usage=True`
        model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_accelerate.to(torch_device)
        model_accelerate.eval()

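        # A fixed generator seed makes the two loading paths directly comparable.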
        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # the two models don't need to be on the device at the same time; free
        # accelerator memory device-agnostically (`torch.cuda.empty_cache()` would
        # only cover CUDA)
        del model_accelerate
        backend_empty_cache(torch_device)
        gc.collect()

        model_normal_load, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
        )
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        arr_normal_load = model_normal_load(noise, time_step)["sample"]

        assert torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3)

    def test_output_pretrained(self):
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
        model.to(torch_device)

        noise = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

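        # Compare a 3x3 corner of the last channel against golden values from a
        # reference run; rtol=1e-3 absorbs minor numerical drift across devices.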
        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 32
        block_out_channels = (32, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )


class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        # a property getter cannot take arguments, so the old `sizes=(32, 32)`
        # parameter is inlined here
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

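    # NCSN++ uses Gaussian Fourier time embeddings and skip-connection style
    # up/down blocks, hence the config differs from the DDPM-style UNets above.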
    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

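    # google/ncsnpp-celebahq-256 expects 256x256 inputs, so the 32x32 dummy sample
    # is swapped out before the forward pass.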
    @slow
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

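    # VE (variance-exploding) score models take continuous timesteps, hence the
    # float value 1e-4 rather than an integer step index.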
    @slow
    def test_output_pretrained_ve_mid(self):
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    def test_output_pretrained_ve_large(self):
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # not required for this model
        pass

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "UNetMidBlock2D",
        }

        block_out_channels = (32, 64, 64, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, block_out_channels=block_out_channels
        )

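    # `time_proj.weight` holds the fixed (non-trainable) Fourier frequencies of the
    # time embedding, so it is excluded from the gradient comparison.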
    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})