# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import tempfile
18
import traceback
19
import unittest
20
import unittest.mock as mock
21
import uuid
22
from typing import Dict, List, Tuple
23
24

import numpy as np
25
import requests_mock
26
import torch
27
from accelerate.utils import compute_module_sizes
28
29
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
30
from requests.exceptions import HTTPError
31

32
from diffusers.models import UNet2DConditionModel
33
34
35
36
37
38
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    AttnProcessorNPU,
    XFormersAttnProcessor,
)
39
from diffusers.training_utils import EMAModel
40
from diffusers.utils import is_torch_npu_available, is_xformers_available, logging
41
42
from diffusers.utils.testing_utils import (
    CaptureLogger,
43
    get_python_version,
Dhruv Nair's avatar
Dhruv Nair committed
44
    require_python39_or_higher,
45
    require_torch_2,
Arsalan's avatar
Arsalan committed
46
    require_torch_accelerator_with_training,
47
    require_torch_gpu,
48
    require_torch_multi_gpu,
49
    run_test_in_subprocess,
Dhruv Nair's avatar
Dhruv Nair committed
50
    torch_device,
51
52
53
)

from ..others.test_utils import TOKEN, USER, is_staging_test
54
55
56
57
58
59
60
61
62
63
64
65
66


# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    """Save and reload a `torch.compile`-d model, reporting the outcome through queues.

    Args:
        in_queue: queue yielding a single `(init_dict, model_class)` pair.
        out_queue: queue that receives `{"error": <traceback string or None>}`.
        timeout: timeout (passed to `get` / `put`) for both queue operations.
    """
    error = None
    try:
        init_dict, model_class = in_queue.get(timeout=timeout)

        model = model_class(**init_dict)
        model.to(torch_device)
        model = torch.compile(model)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        # The reloaded object must be the plain model class again.
        assert new_model.__class__ == model_class
    except Exception:
        # Deliberately broad: any failure is serialized and handed back through `out_queue`
        # instead of being raised here.
        error = traceback.format_exc()

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
78
79


80
class ModelUtilsTest(unittest.TestCase):
    """Checks of `from_pretrained` loading behaviour, exercised through `UNet2DConditionModel`."""

    def tearDown(self):
        super().tearDown()

    def test_accelerate_loading_error_message(self):
        """Loading a checkpoint with missing weights must raise a `ValueError` naming a missing key."""
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        # make sure that error message states what keys are missing
        assert "conv_out.bias" in str(error_context.exception)

    def test_cached_files_are_used_when_no_internet(self):
        """A model that was downloaded once must load with `local_files_only=True` while requests fail."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # Load again, this time restricted to local files only.
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        for p1, p2 in zip(orig_model.parameters(), model.parameters()):
            assert p1.data.ne(p2.data).sum() == 0, "Parameters not the same!"

    def test_one_request_upon_cached(self):
        """Count the HTTP requests issued by a first load and by a second load from the same cache dir."""
        # TODO: For some reason this test fails on MPS where no HEAD call is made.
        if torch_device == "mps":
            return

        use_safetensors = False

        with tempfile.TemporaryDirectory() as tmpdirname:
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch",
                    subfolder="unet",
                    cache_dir=tmpdirname,
                    use_safetensors=use_safetensors,
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch",
                    subfolder="unet",
                    cache_dir=tmpdirname,
                    use_safetensors=use_safetensors,
                )

            # Second load hits the populated cache: exactly one request, and it is a HEAD.
            cache_requests = [r.method for r in m.request_history]
            assert (
                "HEAD" == cache_requests[0] and len(cache_requests) == 1
            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

    def test_weight_overwrite(self):
        """Overriding `in_channels` must fail by default and succeed with `ignore_mismatched_sizes=True`."""
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
            )

        # the error must report that the checkpoint weights cannot be loaded
        assert "Cannot load" in str(error_context.exception)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
                low_cpu_mem_usage=False,
                ignore_mismatched_sizes=True,
            )

        assert model.config.in_channels == 9

172

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class UNetTesterMixin:
    """Checks shared by UNet-style model testers.

    The host test class supplies `model_class` and
    `prepare_init_args_and_inputs_for_common()`, which returns `(init_dict, inputs_dict)`.
    """

    def test_forward_signature(self):
        """`forward` must take `sample` and `timestep` as its first two parameters."""
        init_kwargs, _ = self.prepare_init_args_and_inputs_for_common()
        unet = self.model_class(**init_kwargs)

        # `Signature.parameters` is ordered, so the leading names are deterministic
        parameter_names = list(inspect.signature(unet.forward).parameters)

        self.assertListEqual(parameter_names[:2], ["sample", "timestep"])

    def test_forward_with_norm_groups(self):
        """With 16 norm groups and (16, 32) block channels the output keeps the input sample shape."""
        init_kwargs, forward_kwargs = self.prepare_init_args_and_inputs_for_common()
        init_kwargs["norm_num_groups"] = 16
        init_kwargs["block_out_channels"] = (16, 32)

        unet = self.model_class(**init_kwargs)
        unet.to(torch_device)
        unet.eval()

        with torch.no_grad():
            prediction = unet(**forward_kwargs)
            if isinstance(prediction, dict):
                prediction = prediction.to_tuple()[0]

        self.assertIsNotNone(prediction)
        self.assertEqual(prediction.shape, forward_kwargs["sample"].shape, "Input and output shapes do not match")


206
class ModelTesterMixin:
207
208
    main_input_name = None  # overwrite in model specific tester class
    base_precision = 1e-3
Will Berman's avatar
Will Berman committed
209
    forward_requires_fresh_args = False
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
    model_split_percents = [0.5, 0.7, 0.9]

    def check_device_map_is_respected(self, model, device_map):
        """Assert that every parameter of `model` sits on the device that `device_map` assigns to it.

        Args:
            model: object exposing `named_parameters()`.
            device_map: mapping from (possibly partial) dotted parameter/module names to devices.

        Raises:
            ValueError: if no prefix of a parameter name (down to the empty string) is in `device_map`.
        """
        for param_name, param in model.named_parameters():
            # Find device in device_map: strip trailing name components until a key matches.
            # A separate variable keeps the full parameter name available for the error message.
            map_key = param_name
            while len(map_key) > 0 and map_key not in device_map:
                map_key = ".".join(map_key.split(".")[:-1])
            if map_key not in device_map:
                raise ValueError(
                    f"device map is incomplete, it does not contain any device for `{param_name}`."
                )

            param_device = device_map[map_key]
            if param_device in ["cpu", "disk"]:
                # Parameters mapped to "cpu"/"disk" are expected to be on the meta device.
                self.assertEqual(param.device, torch.device("meta"))
            else:
                self.assertEqual(param.device, torch.device(param_device))
225

226
    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        """A model reloaded from `save_pretrained` output must reproduce the original forward pass.

        Args:
            expected_max_diff: largest tolerated absolute element-wise difference between both outputs.
        """
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
        else:
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            model = self.model_class(**init_dict)

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()
            new_model.to(torch_device)

        with torch.no_grad():
            # With `forward_requires_fresh_args`, `self.inputs_dict` is called to build new inputs
            # for every forward pass instead of reusing one dict.
            if self.forward_requires_fresh_args:
                image = model(**self.inputs_dict(0))
            else:
                image = model(**inputs_dict)

            if isinstance(image, dict):
                image = image.to_tuple()[0]

            if self.forward_requires_fresh_args:
                new_image = new_model(**self.inputs_dict(0))
            else:
                new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
264

265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
    def test_getattr_is_correct(self):
        """Attribute access: plain attributes and methods are silent, config values warn, unknown names raise."""
        init_kwargs, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_kwargs)

        # One ordinary attribute and one value that lives only in the config.
        model.dummy_attribute = 5
        model.register_to_config(test_attribute=5)

        modeling_logger = logging.get_logger("diffusers.models.modeling_utils")
        modeling_logger.setLevel(30)  # 30 for warning
        with CaptureLogger(modeling_logger) as captured:
            assert hasattr(model, "dummy_attribute")
            assert getattr(model, "dummy_attribute") == 5
            assert model.dummy_attribute == 5
        # reading an ordinary attribute must not log anything
        assert captured.out == ""

        modeling_logger = logging.get_logger("diffusers.models.modeling_utils")
        modeling_logger.setLevel(30)  # 30 for warning
        with CaptureLogger(modeling_logger) as captured:
            assert hasattr(model, "save_pretrained")
            direct_method = model.save_pretrained
            looked_up_method = getattr(model, "save_pretrained")
            assert direct_method == looked_up_method
        # reading a method must not log anything either
        assert captured.out == ""

        # config values reached through the model emit a FutureWarning
        with self.assertWarns(FutureWarning):
            assert model.test_attribute == 5
        with self.assertWarns(FutureWarning):
            assert getattr(model, "test_attribute") == 5

        with self.assertRaises(AttributeError) as raised:
            model.does_not_exist
        assert str(raised.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    @unittest.skipIf(
        torch_device != "npu" or not is_torch_npu_available(),
        reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
    )
    def test_set_torch_npu_flash_attn_processor_determinism(self):
        """Outputs must agree within `base_precision` for each way of selecting the NPU processor.

        Deterministic algorithms are switched off while the forward passes run and are always
        switched back on afterwards, including on the early return and when a check fails.
        """
        torch.use_deterministic_algorithms(False)
        try:
            if self.forward_requires_fresh_args:
                model = self.model_class(**self.init_dict)
            else:
                init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
                model = self.model_class(**init_dict)
            model.to(torch_device)

            if not hasattr(model, "set_attn_processor"):
                # If not has `set_attn_processor`, skip test
                return

            model.set_default_attn_processor()
            # NOTE(review): the default processor is asserted to be `AttnProcessorNPU` here — confirm
            # that this is what `set_default_attn_processor` installs on NPU.
            assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output = model(**self.inputs_dict(0))[0]
                else:
                    output = model(**inputs_dict)[0]

            model.enable_npu_flash_attention()
            assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_2 = model(**self.inputs_dict(0))[0]
                else:
                    output_2 = model(**inputs_dict)[0]

            model.set_attn_processor(AttnProcessorNPU())
            assert all(type(proc) == AttnProcessorNPU for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_3 = model(**self.inputs_dict(0))[0]
                else:
                    output_3 = model(**inputs_dict)[0]
        finally:
            # Global torch state: restore it on every exit path so later tests are unaffected.
            torch.use_deterministic_algorithms(True)

        assert torch.allclose(output, output_2, atol=self.base_precision)
        assert torch.allclose(output, output_3, atol=self.base_precision)
        assert torch.allclose(output_2, output_3, atol=self.base_precision)

Dhruv Nair's avatar
Dhruv Nair committed
355
356
357
358
359
360
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_set_xformers_attn_processor_for_determinism(self):
        """Default and xformers attention processors must give outputs within `base_precision`.

        Deterministic algorithms are switched off while the forward passes run and are always
        switched back on afterwards, including on the early return and when a check fails.
        """
        torch.use_deterministic_algorithms(False)
        try:
            if self.forward_requires_fresh_args:
                model = self.model_class(**self.init_dict)
            else:
                init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
                model = self.model_class(**init_dict)
            model.to(torch_device)

            if not hasattr(model, "set_attn_processor"):
                # If not has `set_attn_processor`, skip test
                return

            model.set_default_attn_processor()
            assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output = model(**self.inputs_dict(0))[0]
                else:
                    output = model(**inputs_dict)[0]

            model.enable_xformers_memory_efficient_attention()
            assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_2 = model(**self.inputs_dict(0))[0]
                else:
                    output_2 = model(**inputs_dict)[0]

            model.set_attn_processor(XFormersAttnProcessor())
            assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_3 = model(**self.inputs_dict(0))[0]
                else:
                    output_3 = model(**inputs_dict)[0]
        finally:
            # Global torch state: restore it on every exit path so later tests are unaffected.
            torch.use_deterministic_algorithms(True)

        assert torch.allclose(output, output_2, atol=self.base_precision)
        assert torch.allclose(output, output_3, atol=self.base_precision)
        assert torch.allclose(output_2, output_3, atol=self.base_precision)
Dhruv Nair's avatar
Dhruv Nair committed
401

402
403
404
    @require_torch_gpu
    def test_set_attn_processor_for_determinism(self):
        """`AttnProcessor` and `AttnProcessor2_0` must give outputs within `base_precision` of each other.

        Deterministic algorithms are switched off while the forward passes run and are always
        switched back on afterwards, including on the early return and when a check fails.
        """
        torch.use_deterministic_algorithms(False)
        try:
            if self.forward_requires_fresh_args:
                model = self.model_class(**self.init_dict)
            else:
                init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
                model = self.model_class(**init_dict)
            model.to(torch_device)

            if not hasattr(model, "set_attn_processor"):
                # If not has `set_attn_processor`, skip test
                return

            # A freshly built model is expected to start with `AttnProcessor2_0`.
            assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_1 = model(**self.inputs_dict(0))[0]
                else:
                    output_1 = model(**inputs_dict)[0]

            model.set_default_attn_processor()
            assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_2 = model(**self.inputs_dict(0))[0]
                else:
                    output_2 = model(**inputs_dict)[0]

            model.set_attn_processor(AttnProcessor2_0())
            assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_4 = model(**self.inputs_dict(0))[0]
                else:
                    output_4 = model(**inputs_dict)[0]

            model.set_attn_processor(AttnProcessor())
            assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
            with torch.no_grad():
                if self.forward_requires_fresh_args:
                    output_5 = model(**self.inputs_dict(0))[0]
                else:
                    output_5 = model(**inputs_dict)[0]
        finally:
            # Global torch state: restore it on every exit path so later tests are unaffected.
            torch.use_deterministic_algorithms(True)

        # make sure that outputs match
        assert torch.allclose(output_2, output_1, atol=self.base_precision)
        assert torch.allclose(output_2, output_4, atol=self.base_precision)
        assert torch.allclose(output_2, output_5, atol=self.base_precision)

455
    def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
        """Weights saved under `variant="fp16"` must reload only with that variant and match the original.

        Args:
            expected_max_diff: largest tolerated absolute element-wise difference between both outputs.
        """
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
        else:
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            model = self.model_class(**init_dict)

        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()

            # non-variant cannot be loaded
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(tmpdirname)

            # the error must name the non-variant weights file that is absent
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)

            new_model.to(torch_device)

        with torch.no_grad():
            if self.forward_requires_fresh_args:
                image = model(**self.inputs_dict(0))
            else:
                image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            if self.forward_requires_fresh_args:
                new_image = new_model(**self.inputs_dict(0))
            else:
                new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
501

Dhruv Nair's avatar
Dhruv Nair committed
502
    @require_python39_or_higher
    @require_torch_2
    @unittest.skipIf(
        # `get_python_version` has to be called: comparing the function object itself with a
        # tuple is always False, so the skip could never trigger.
        get_python_version() == (3, 12),
        reason="Torch Dynamo isn't yet supported for Python 3.12.",
    )
    def test_from_save_pretrained_dynamo(self):
        """Run the compiled-model save/load round trip (`_test_from_save_pretrained_dynamo`) in a subprocess."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        inputs = [init_dict, self.model_class]
        run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs)
512

513
514
515
516
517
518
519
520
521
522
523
524
    def test_from_save_pretrained_dtype(self):
        """`from_pretrained(..., torch_dtype=dtype)` must yield that dtype with and without `low_cpu_mem_usage`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # bfloat16 is skipped when running on "mps"
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname, safe_serialization=False)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
                assert new_model.dtype == dtype

531
    def test_determinism(self, expected_max_diff=1e-5):
        """Two forward passes of the same model must agree within `expected_max_diff`.

        Args:
            expected_max_diff: largest tolerated absolute element-wise difference between both outputs.
        """
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
        else:
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            if self.forward_requires_fresh_args:
                first = model(**self.inputs_dict(0))
            else:
                first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.to_tuple()[0]

            if self.forward_requires_fresh_args:
                second = model(**self.inputs_dict(0))
            else:
                second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second.to_tuple()[0]

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # NaN entries are dropped from each output separately before comparing.
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, expected_max_diff)
561

562
    def test_output(self, expected_output_shape=None):
        """The forward output must exist and have the expected shape.

        Args:
            expected_output_shape: shape the output must have; when `None`, the shape of the
                main input (`inputs_dict[self.main_input_name]`) is used instead.
        """
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)

        # input & output have to have the same shape
        input_tensor = inputs_dict[self.main_input_name]

        if expected_output_shape is None:
            expected_shape = input_tensor.shape
            self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
        else:
            self.assertEqual(output.shape, expected_output_shape, "Input and output shapes do not match")
584

585
    def test_model_from_pretrained(self):
        """A save/load round trip must keep every parameter shape and the output shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        # (state dicts are built once instead of on every loop iteration)
        state_dict = model.state_dict()
        new_state_dict = new_model.state_dict()
        for param_name, param_1 in state_dict.items():
            param_2 = new_state_dict[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1.to_tuple()[0]

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2.to_tuple()[0]

        self.assertEqual(output_1.shape, output_2.shape)

Arsalan's avatar
Arsalan committed
619
    @require_torch_accelerator_with_training
    def test_training(self):
        """A forward pass in train mode followed by an MSE loss backward pass must run."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        # Random target: batch size of the main input followed by `self.output_shape`.
        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

Arsalan's avatar
Arsalan committed
636
    @require_torch_accelerator_with_training
    def test_ema_training(self):
        """One forward/backward pass followed by an `EMAModel.step` on the model parameters must run."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        # Random target: batch size of the main input followed by `self.output_shape`.
        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())
655

656
    def test_outputs_equivalence(self):
        """Outputs obtained with and without `return_dict=False` must hold the same values."""

        def set_nan_tensor_to_zero(t):
            """Zero the NaN entries of `t` in place and return it on its original device."""
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            # `t != t` is True exactly at NaN positions
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            """Walk both outputs in parallel and compare the leaf tensors with `atol=1e-5`."""
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                    ),
                )

        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
        else:
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            model = self.model_class(**init_dict)

        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            if self.forward_requires_fresh_args:
                outputs_dict = model(**self.inputs_dict(0))
                outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
            else:
                outputs_dict = model(**inputs_dict)
                outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)
706

Arsalan's avatar
Arsalan committed
707
    @require_torch_accelerator_with_training
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
    def test_enable_disable_gradient_checkpointing(self):
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # at init model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        # check enable works
        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        # check disable works
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744

    def test_deprecated_kwargs(self):
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )
745

746
747
748
749
    @require_torch_gpu
    def test_cpu_offload(self):
        """Check that a model loaded with `device_map="auto"` across GPU 0 and CPU reproduces the base output."""
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        if model._no_split_modules is None:
            # Report an explicit skip: a bare `return` would be recorded as a passing test.
            self.skipTest("Model class does not define `_no_split_modules`.")

        model = model.to(torch_device)

        # Seed before each forward pass so the two outputs are comparable.
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_without_safetensors(self):
        """Check disk offload of a non-safetensors checkpoint: an offload folder is required, then outputs match."""
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        if model._no_split_modules is None:
            # Report an explicit skip: a bare `return` would be recorded as a passing test.
            self.skipTest("Model class does not define `_no_split_modules`.")

        model = model.to(torch_device)

        # Seed before each forward pass so the two outputs are comparable.
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # The same budget is given to GPU 0 and to the CPU for both loads below.
        max_size = int(self.model_split_percents[0] * model_size)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

            max_memory = {0: max_size, "cpu": max_size}
            # Only the load sits inside `assertRaises`, so a ValueError raised elsewhere cannot satisfy it.
            with self.assertRaises(ValueError):
                # This errors out because it's missing an offload folder
                self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_with_safetensors(self):
        """Check that a default-serialized checkpoint loaded with an offload folder reproduces the base output."""
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        if model._no_split_modules is None:
            # Report an explicit skip: a bare `return` would be recorded as a passing test.
            self.skipTest("Model class does not define `_no_split_modules`.")

        model = model.to(torch_device)

        # Seed before each forward pass so the two outputs are comparable.
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            # The same budget is given to GPU 0 and to the CPU.
            max_size = int(self.model_split_percents[0] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_multi_gpu
    def test_model_parallelism(self):
        """Check that a model split across GPU 0 and GPU 1 by `device_map="auto"` reproduces the base output."""
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        if model._no_split_modules is None:
            # Report an explicit skip: a bare `return` would be recorded as a passing test.
            self.skipTest("Model class does not define `_no_split_modules`.")

        model = model.to(torch_device)

        # Seed before each forward pass so the two outputs are comparable.
        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure the model is actually spread over both GPUs (and nothing lands on the CPU)
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936

@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
    """Hub round-trip tests for a small `UNet2DConditionModel`, gated by `is_staging_test`.

    Each test pushes the model, verifies what is on the Hub, and deletes the repo again. The
    deletion runs in a `finally` block so a failed verification does not leave the repo behind.
    """

    # A fresh identifier per test session keeps repo names from colliding between runs.
    identifier = uuid.uuid4()
    repo_id = f"test-model-{identifier}"
    org_repo_id = f"valid_org/{repo_id}-org"

    @staticmethod
    def _get_dummy_model():
        """Build the small UNet used by every test in this class."""
        return UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )

    def _assert_hub_weights_match(self, model, hub_repo_id):
        """Reload `hub_repo_id` from the Hub and assert its parameters equal those of `model`."""
        new_model = UNet2DConditionModel.from_pretrained(hub_repo_id)
        for p1, p2 in zip(model.parameters(), new_model.parameters()):
            self.assertTrue(torch.equal(p1, p2))

    def _check_push_round_trips(self, repo_id, hub_repo_id):
        """Push via `push_to_hub` and via `save_pretrained(push_to_hub=True)`, verifying each upload.

        Args:
            repo_id: Repo name passed to the push and delete calls.
            hub_repo_id: Repo name used to load the model back from the Hub.
        """
        model = self._get_dummy_model()

        model.push_to_hub(repo_id, token=TOKEN)
        try:
            self._assert_hub_weights_match(model, hub_repo_id)
        finally:
            # Reset repo
            delete_repo(token=TOKEN, repo_id=repo_id)

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, repo_id=repo_id, push_to_hub=True, token=TOKEN)
        try:
            self._assert_hub_weights_match(model, hub_repo_id)
        finally:
            # Reset repo
            delete_repo(repo_id, token=TOKEN)

    def test_push_to_hub(self):
        """Round-trip the model through a repo under the test user's namespace."""
        self._check_push_round_trips(self.repo_id, f"{USER}/{self.repo_id}")

    def test_push_to_hub_in_organization(self):
        """Round-trip the model through a repo under the `valid_org` namespace."""
        self._check_push_round_trips(self.org_repo_id, self.org_repo_id)

    @unittest.skipIf(
        not is_jinja_available(),
        reason="Model card tests cannot be performed without Jinja installed.",
    )
    def test_push_to_hub_library_name(self):
        """Check that the model card written by `push_to_hub` declares `diffusers` as its library."""
        model = self._get_dummy_model()
        model.push_to_hub(self.repo_id, token=TOKEN)

        try:
            model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
            self.assertEqual(model_card.library_name, "diffusers")
        finally:
            # Reset repo
            delete_repo(self.repo_id, token=TOKEN)