# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import unittest

import torch

from diffusers import UNet2DModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    floats_tensor,
    require_torch_accelerator,
    slow,
    torch_all_close,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


logger = logging.get_logger(__name__)

enable_full_determinism()


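# Tests UNet2DModel with a tiny configuration that mixes plain and
# attention-augmented down/up blocks.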
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 2,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "attention_head_dim": 3,
            "out_channels": 3,
            "in_channels": 3,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

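    # `attn_norm_num_groups` should be threaded through to the mid block's attention
    # group norm, and the forward pass should preserve the sample shape.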
    def test_mid_block_attn_groups(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["add_attention"] = True
        init_dict["attn_norm_num_groups"] = 4

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        self.assertIsNotNone(
            model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
        )
        self.assertEqual(
            model.mid_block.attentions[0].group_norm.num_groups,
            init_dict["attn_norm_num_groups"],
            "Mid block Attention group norm does not have the expected number of groups.",
        )

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

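    # With `mid_block_type=None` the model should be built without a mid block, and
    # its output should differ from an otherwise identical model that keeps one.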
    def test_mid_block_none(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
        mid_none_init_dict["mid_block_type"] = None

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        mid_none_model = self.model_class(**mid_none_init_dict)
        mid_none_model.to(torch_device)
        mid_none_model.eval()

        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        with torch.no_grad():
            mid_none_output = mid_none_model(**mid_none_inputs_dict)

            if isinstance(mid_none_output, dict):
                mid_none_output = mid_none_output.to_tuple()[0]

        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "Outputs should differ when the mid block is removed.")

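    # Gradient checkpointing should be applied to every block type listed in `expected_set`.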
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "AttnUpBlock2D",
            "AttnDownBlock2D",
            "UNetMidBlock2D",
            "UpBlock2D",
            "DownBlock2D",
        }

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 8
        block_out_channels = (16, 32)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )


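# Tests UNet2DModel in an LDM-style configuration: four latent channels and plain
# DownBlock2D/UpBlock2D stacks, loaded from small checkpoints on the Hub.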
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "sample_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 2,
            "block_out_channels": (32, 64),
            "attention_head_dim": 32,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

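    # Loading a small dummy checkpoint from the Hub should report no missing keys
    # and produce a non-None output.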
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

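    # Smoke test: the default accelerate-backed loading path should yield a model
    # that runs a forward pass on the accelerator.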
    @require_torch_accelerator
    def test_from_pretrained_accelerate(self):
        model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

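    # Loading with and without accelerate (`low_cpu_mem_usage` True vs. False) should
    # produce numerically matching outputs for the same seeded input.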
    @require_torch_accelerator
    def test_from_pretrained_accelerate_wont_change_results(self):
        # by default, model loading uses accelerate because `low_cpu_mem_usage=True`
        model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_accelerate.to(torch_device)
        model_accelerate.eval()

        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # the two models do not need to be on the device at the same time, so free the first one
        del model_accelerate
        backend_empty_cache(torch_device)
        gc.collect()

        model_normal_load, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, low_cpu_mem_usage=False
        )
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        arr_normal_load = model_normal_load(noise, time_step)["sample"]

        assert torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3)

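    # Regression test: a fixed-seed forward pass is compared against hard-coded
    # reference values.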
    def test_output_pretrained(self):
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
        model.to(torch_device)

        noise = torch.randn(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -1, -3:, -3:].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}

        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        attention_head_dim = 32
        block_out_channels = (32, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
        )


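# Tests the score-based (NCSN++) variant of UNet2DModel: Fourier time embeddings
# and skip-connection blocks, matching the VE checkpoints used below.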
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DModel
    main_input_name = "sample"

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)

        return {"sample": noise, "timestep": time_step}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": [32, 64, 64, 64],
            "in_channels": 3,
            "layers_per_block": 1,
            "out_channels": 3,
            "time_embedding_type": "fourier",
            "norm_eps": 1e-6,
            "mid_block_scale_factor": math.sqrt(2.0),
            "norm_num_groups": None,
            "down_block_types": [
                "SkipDownBlock2D",
                "AttnSkipDownBlock2D",
                "SkipDownBlock2D",
                "SkipDownBlock2D",
            ],
            "up_block_types": [
                "SkipUpBlock2D",
                "SkipUpBlock2D",
                "AttnSkipUpBlock2D",
                "SkipUpBlock2D",
            ],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

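    # Slow: downloads the full 256x256 CelebA-HQ checkpoint and runs a single
    # forward pass end to end.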
    @slow
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        inputs = self.dummy_input
        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
        inputs["sample"] = noise
        image = model(**inputs)

        assert image is not None, "Make sure output is not None"

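    # Slow regression test: the 256x256 CelebA-HQ checkpoint is evaluated at a
    # near-zero continuous timestep and compared against reference values.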
    @slow
    def test_output_pretrained_ve_mid(self):
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (256, 256)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

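    # Same regression check as above, but against a small dummy VE checkpoint so
    # that it can run outside the slow suite.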
    def test_output_pretrained_ve_large(self):
        model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
        model.to(torch_device)

        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = torch.ones((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)

        with torch.no_grad():
            output = model(noise, time_step).sample

        output_slice = output[0, -3:, -3:, -1].flatten().cpu()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
        # fmt: on

        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

    @unittest.skip("Test not supported.")
    def test_forward_with_norm_groups(self):
        # not required for this model
        pass

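    # For this skip-block configuration, only UNetMidBlock2D is expected to be
    # gradient-checkpointed.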
    def test_gradient_checkpointing_is_applied(self):
        expected_set = {
            "UNetMidBlock2D",
        }

        block_out_channels = (32, 64, 64, 64)

        super().test_gradient_checkpointing_is_applied(
            expected_set=expected_set, block_out_channels=block_out_channels
        )

    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_memory(self):
        pass