# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import tempfile
18
import traceback
19
import unittest
20
import unittest.mock as mock
21
import uuid
22
from typing import Dict, List, Tuple
23
24

import numpy as np
25
import requests_mock
26
import torch
27
from huggingface_hub import delete_repo
28
from requests.exceptions import HTTPError
29

30
from diffusers.models import UNet2DConditionModel
31
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
32
from diffusers.training_utils import EMAModel
Dhruv Nair's avatar
Dhruv Nair committed
33
from diffusers.utils import is_xformers_available, logging
34
35
from diffusers.utils.testing_utils import (
    CaptureLogger,
Dhruv Nair's avatar
Dhruv Nair committed
36
    require_python39_or_higher,
37
    require_torch_2,
Arsalan's avatar
Arsalan committed
38
    require_torch_accelerator_with_training,
39
40
    require_torch_gpu,
    run_test_in_subprocess,
Dhruv Nair's avatar
Dhruv Nair committed
41
    torch_device,
42
43
44
)

from ..others.test_utils import TOKEN, USER, is_staging_test
45
46
47
48
49
50
51
52
53
54
55
56
57


# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    error = None
    try:
        init_dict, model_class = in_queue.get(timeout=timeout)

        model = model_class(**init_dict)
        model.to(torch_device)
        model = torch.compile(model)

        with tempfile.TemporaryDirectory() as tmpdirname:
58
            model.save_pretrained(tmpdirname, safe_serialization=False)
59
60
61
62
63
64
65
66
67
68
            new_model = model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        assert new_model.__class__ == model_class
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
69
70


71
class ModelUtilsTest(unittest.TestCase):
    """Checks of `UNet2DConditionModel.from_pretrained` error reporting, caching and weight overriding."""

    def tearDown(self):
        super().tearDown()

    def test_accelerate_loading_error_message(self):
        """Loading a checkpoint with missing weights must raise a `ValueError` naming a missing key."""
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        # make sure that error message states what keys are missing
        assert "conv_out.bias" in str(error_context.exception)

    def test_cached_files_are_used_when_no_internet(self):
        """A model that is already cached must load with `local_files_only=True` while requests return 500."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # Load again without network access; only the cached files can be used here.
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        for p1, p2 in zip(orig_model.parameters(), model.parameters()):
            self.assertTrue(torch.equal(p1, p2), "Parameters not the same!")

    def test_one_request_upon_cached(self):
        """The first download issues 2 HEAD + 2 GET requests; a second, cached load issues a single HEAD."""
        # TODO: For some reason this test fails on MPS where no HEAD call is made.
        if torch_device == "mps":
            return

        use_safetensors = False

        with tempfile.TemporaryDirectory() as tmpdirname:
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch",
                    subfolder="unet",
                    cache_dir=tmpdirname,
                    use_safetensors=use_safetensors,
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch",
                    subfolder="unet",
                    cache_dir=tmpdirname,
                    use_safetensors=use_safetensors,
                )

            cache_requests = [r.method for r in m.request_history]
            # Check the length first so that an empty history fails with the message below, not an IndexError.
            assert (
                len(cache_requests) == 1 and cache_requests[0] == "HEAD"
            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

    def test_weight_overwrite(self):
        """Overriding `in_channels` fails by default and succeeds with `ignore_mismatched_sizes=True`."""
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
            )

        # make sure that the error message reports that the weights could not be loaded
        assert "Cannot load" in str(error_context.exception)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
                low_cpu_mem_usage=False,
                ignore_mismatched_sizes=True,
            )

        assert model.config.in_channels == 9

163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class UNetTesterMixin:
    """Checks for UNet-style models; relies on `model_class` and `prepare_init_args_and_inputs_for_common()`."""

    def test_forward_signature(self):
        """`forward` must take `sample` and `timestep` as its first two parameters."""
        init_kwargs, _ = self.prepare_init_args_and_inputs_for_common()
        unet = self.model_class(**init_kwargs)

        # Signature parameters keep their declaration order, so slicing the names is deterministic.
        leading_arg_names = list(inspect.signature(unet.forward).parameters)[:2]
        self.assertListEqual(leading_arg_names, ["sample", "timestep"])

    def test_forward_with_norm_groups(self):
        """A model built with 16 norm groups and (16, 32) block channels must keep the shape of `sample`."""
        init_kwargs, forward_kwargs = self.prepare_init_args_and_inputs_for_common()
        init_kwargs.update(norm_num_groups=16, block_out_channels=(16, 32))

        unet = self.model_class(**init_kwargs)
        unet.to(torch_device)
        unet.eval()

        with torch.no_grad():
            result = unet(**forward_kwargs)
            # Dict-like outputs are reduced to their first element.
            result = result.to_tuple()[0] if isinstance(result, dict) else result

        self.assertIsNotNone(result)
        self.assertEqual(result.shape, forward_kwargs["sample"].shape, "Input and output shapes do not match")


197
class ModelTesterMixin:
    """Model-agnostic tests that are mixed into the per-model test classes.

    The concrete test class is expected to provide `model_class` and `prepare_init_args_and_inputs_for_common()`;
    the training tests additionally read `output_shape`. When `forward_requires_fresh_args` is true, the model is
    built from `self.init_dict` and `self.inputs_dict(0)` is called again before every forward pass instead of
    reusing one inputs dict.
    """

    main_input_name = None  # overwrite in model specific tester class
    base_precision = 1e-3
    forward_requires_fresh_args = False

    def _build_model_and_inputs(self):
        """Return `(model, inputs_dict)`; `inputs_dict` is `None` when inputs are recreated per forward pass."""
        if self.forward_requires_fresh_args:
            return self.model_class(**self.init_dict), None
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        return self.model_class(**init_dict), inputs_dict

    def _forward_kwargs(self, inputs_dict):
        """Return the kwargs for one forward pass: a new `self.inputs_dict(0)` or the shared `inputs_dict`."""
        if self.forward_requires_fresh_args:
            return self.inputs_dict(0)
        return inputs_dict

    @staticmethod
    def _first_output(output):
        """Reduce a dict-like model output to its first element; other outputs are returned unchanged."""
        if isinstance(output, dict):
            output = output.to_tuple()[0]
        return output

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        """A saved and reloaded model must reproduce the original forward pass within `expected_max_diff`."""
        model, inputs_dict = self._build_model_and_inputs()

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()
            new_model.to(torch_device)

        with torch.no_grad():
            image = self._first_output(model(**self._forward_kwargs(inputs_dict)))
            new_image = self._first_output(new_model(**self._forward_kwargs(inputs_dict)))

        max_diff = (image - new_image).abs().max().item()
        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

    def test_getattr_is_correct(self):
        """Plain attributes resolve silently, config attributes warn, unknown attributes raise."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        # save some things to test
        model.dummy_attribute = 5
        model.register_to_config(test_attribute=5)

        logger = logging.get_logger("diffusers.models.modeling_utils")
        # 30 for warning
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            assert hasattr(model, "dummy_attribute")
            assert getattr(model, "dummy_attribute") == 5
            assert model.dummy_attribute == 5

        # no warning should be thrown
        assert cap_logger.out == ""

        with CaptureLogger(logger) as cap_logger:
            assert hasattr(model, "save_pretrained")
            fn = model.save_pretrained
            fn_1 = getattr(model, "save_pretrained")

            assert fn == fn_1
        # no warning should be thrown
        assert cap_logger.out == ""

        # warning should be thrown
        with self.assertWarns(FutureWarning):
            assert model.test_attribute == 5

        with self.assertWarns(FutureWarning):
            assert getattr(model, "test_attribute") == 5

        with self.assertRaises(AttributeError) as error:
            model.does_not_exist

        assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_set_xformers_attn_processor_for_determinism(self):
        """Default and xformers attention processors must agree within `base_precision`."""
        torch.use_deterministic_algorithms(False)
        try:
            model, inputs_dict = self._build_model_and_inputs()
            model.to(torch_device)

            if not hasattr(model, "set_attn_processor"):
                # If not has `set_attn_processor`, skip test
                return

            model.set_default_attn_processor()
            assert all(type(proc) is AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                output = model(**self._forward_kwargs(inputs_dict))[0]

            model.enable_xformers_memory_efficient_attention()
            assert all(type(proc) is XFormersAttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                output_2 = model(**self._forward_kwargs(inputs_dict))[0]

            model.set_attn_processor(XFormersAttnProcessor())
            assert all(type(proc) is XFormersAttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                output_3 = model(**self._forward_kwargs(inputs_dict))[0]
        finally:
            # Re-enable the global flag on every exit path (early return or failure), not only on success.
            torch.use_deterministic_algorithms(True)

        assert torch.allclose(output, output_2, atol=self.base_precision)
        assert torch.allclose(output, output_3, atol=self.base_precision)
        assert torch.allclose(output_2, output_3, atol=self.base_precision)

    @require_torch_gpu
    def test_set_attn_processor_for_determinism(self):
        """`AttnProcessor2_0` and `AttnProcessor` must give matching outputs within `base_precision`."""
        torch.use_deterministic_algorithms(False)
        try:
            model, inputs_dict = self._build_model_and_inputs()
            model.to(torch_device)

            if not hasattr(model, "set_attn_processor"):
                # If not has `set_attn_processor`, skip test
                return

            assert all(type(proc) is AttnProcessor2_0 for proc in model.attn_processors.values())
            with torch.no_grad():
                output_1 = model(**self._forward_kwargs(inputs_dict))[0]

            model.set_default_attn_processor()
            assert all(type(proc) is AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                output_2 = model(**self._forward_kwargs(inputs_dict))[0]

            model.set_attn_processor(AttnProcessor2_0())
            assert all(type(proc) is AttnProcessor2_0 for proc in model.attn_processors.values())
            with torch.no_grad():
                output_4 = model(**self._forward_kwargs(inputs_dict))[0]

            model.set_attn_processor(AttnProcessor())
            assert all(type(proc) is AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                output_5 = model(**self._forward_kwargs(inputs_dict))[0]
        finally:
            # Re-enable the global flag on every exit path (early return or failure), not only on success.
            torch.use_deterministic_algorithms(True)

        # make sure that outputs match
        assert torch.allclose(output_2, output_1, atol=self.base_precision)
        assert torch.allclose(output_2, output_4, atol=self.base_precision)
        assert torch.allclose(output_2, output_5, atol=self.base_precision)

    def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
        """A model saved with `variant="fp16"` reloads only with that variant and reproduces the forward pass."""
        model, inputs_dict = self._build_model_and_inputs()

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()

            # non-variant cannot be loaded
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(tmpdirname)

            # make sure that the error message names the weights file that could not be found
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)

            new_model.to(torch_device)

        with torch.no_grad():
            image = self._first_output(model(**self._forward_kwargs(inputs_dict)))
            new_image = self._first_output(new_model(**self._forward_kwargs(inputs_dict)))

        max_diff = (image - new_image).abs().max().item()
        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

    @require_python39_or_higher
    @require_torch_2
    def test_from_save_pretrained_dynamo(self):
        """Run the compiled-model save/load round trip through `run_test_in_subprocess`."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        inputs = [init_dict, self.model_class]
        run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs)

    def test_from_save_pretrained_dtype(self):
        """`from_pretrained(torch_dtype=...)` must honour the dtype with and without `low_cpu_mem_usage`."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # bfloat16 is not exercised on mps.
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname, safe_serialization=False)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
                assert new_model.dtype == dtype

    def test_determinism(self, expected_max_diff=1e-5):
        """Two forward passes of the same model must agree within `expected_max_diff`."""
        model, inputs_dict = self._build_model_and_inputs()
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = self._first_output(model(**self._forward_kwargs(inputs_dict)))
            second = self._first_output(model(**self._forward_kwargs(inputs_dict)))

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # NaN entries are dropped from each output before the comparison.
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, expected_max_diff)

    def test_output(self):
        """The model output must have the same shape as the main input."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = self._first_output(model(**inputs_dict))

        self.assertIsNotNone(output)

        # input & output have to have the same shape
        input_tensor = inputs_dict[self.main_input_name]
        expected_shape = input_tensor.shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_from_pretrained(self):
        """A reloaded model must have the same parameter shapes and output shape as the original."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        # (both state dicts are built once instead of once per parameter)
        state_dict = model.state_dict()
        new_state_dict = new_model.state_dict()
        for param_name, param_1 in state_dict.items():
            param_2 = new_state_dict[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = self._first_output(model(**inputs_dict))
            output_2 = self._first_output(new_model(**inputs_dict))

        self.assertEqual(output_1.shape, output_2.shape)

    @require_torch_accelerator_with_training
    def test_training(self):
        """One forward/backward pass with an MSE loss against random noise must run."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = self._first_output(model(**inputs_dict))

        input_tensor = inputs_dict[self.main_input_name]
        # The target keeps the input's leading dimension followed by `self.output_shape`.
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    @require_torch_accelerator_with_training
    def test_ema_training(self):
        """Same as `test_training`, followed by one `EMAModel.step` on the model parameters."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = self._first_output(model(**inputs_dict))

        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())

    def test_outputs_equivalence(self):
        """Outputs returned as a dict and with `return_dict=False` must match element by element."""

        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            # NaN is the only value that differs from itself.
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (list, tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                    ),
                )

        model, inputs_dict = self._build_model_and_inputs()

        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**self._forward_kwargs(inputs_dict))
            outputs_tuple = model(**self._forward_kwargs(inputs_dict), return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    @require_torch_accelerator_with_training
    def test_enable_disable_gradient_checkpointing(self):
        """Gradient checkpointing starts disabled and follows the enable/disable calls."""
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # at init model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        # check enable works
        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        # check disable works
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)

    def test_deprecated_kwargs(self):
        """`**kwargs` in `__init__` and a non-empty `_deprecated_kwargs` must be declared together.

        Raises:
            ValueError: if exactly one of the two is present on `model_class`.
        """
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
    """Push a small `UNet2DConditionModel` to the Hub and check that the downloaded copy is identical."""

    # A random suffix is generated once per class definition and shared by both repo ids.
    identifier = uuid.uuid4()
    repo_id = f"test-model-{identifier}"
    org_repo_id = f"valid_org/{repo_id}-org"

    @staticmethod
    def _build_model():
        """Return the small UNet that both tests upload."""
        return UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )

    def _check_hub_copy_and_delete(self, model, load_repo_id, delete_repo_id):
        """Compare `model` with the copy loaded from `load_repo_id`, then delete `delete_repo_id`.

        The repo is deleted in a `finally` block so that a failed download or comparison does not leave it behind.
        """
        try:
            new_model = UNet2DConditionModel.from_pretrained(load_repo_id)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))
        finally:
            # Reset repo
            delete_repo(delete_repo_id, token=TOKEN)

    def test_push_to_hub(self):
        """Upload to the user namespace via `push_to_hub` and via `save_pretrained(push_to_hub=True)`."""
        model = self._build_model()

        model.push_to_hub(self.repo_id, token=TOKEN)
        self._check_hub_copy_and_delete(model, f"{USER}/{self.repo_id}", self.repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
        self._check_hub_copy_and_delete(model, f"{USER}/{self.repo_id}", self.repo_id)

    def test_push_to_hub_in_organization(self):
        """Upload to the organization repo via `push_to_hub` and via `save_pretrained(push_to_hub=True)`."""
        model = self._build_model()

        model.push_to_hub(self.org_repo_id, token=TOKEN)
        self._check_hub_copy_and_delete(model, self.org_repo_id, self.org_repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
        self._check_hub_copy_and_delete(model, self.org_repo_id, self.org_repo_id)