# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import tempfile
18
import unittest
19
import unittest.mock as mock
20
from typing import Dict, List, Tuple
21
22

import numpy as np
23
import requests_mock
24
import torch
25
from requests.exceptions import HTTPError
26

27
from diffusers.models import ModelMixin, UNet2DConditionModel
28
from diffusers.models.attention_processor import AttnProcessor
29
from diffusers.training_utils import EMAModel
30
from diffusers.utils import torch_device
31
32


33
class ModelUtilsTest(unittest.TestCase):
    """Tests for `UNet2DConditionModel.from_pretrained`: error reporting and reuse of cached files."""

    def tearDown(self):
        super().tearDown()

        import diffusers

        # `test_one_request_upon_cached` switches this flag off. Re-enable it here so that a
        # failing assertion in that test cannot leak the override into later tests.
        diffusers.utils.import_utils._safetensors_available = True

    def test_accelerate_loading_error_message(self):
        """Loading the `stable-diffusion-broken` UNet must raise a `ValueError` that names the missing key."""
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        # make sure that error message states what keys are missing
        self.assertIn("conv_out.bias", str(error_context.exception))

    def test_cached_files_are_used_when_no_internet(self):
        """A model that is already cached must load identically while `requests.request` returns HTTP 500."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # Load the same model again, this time restricted to local files.
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        for p1, p2 in zip(orig_model.parameters(), model.parameters()):
            self.assertTrue(torch.equal(p1.data, p2.data), "Parameters not the same!")

    def test_one_request_upon_cached(self):
        """Count HTTP requests for a cold download and for a second load from the same cache directory."""
        # TODO: For some reason this test fails on MPS where no HEAD call is made.
        if torch_device == "mps":
            # Report the test as skipped rather than letting it pass silently.
            self.skipTest("No HEAD call is made on MPS")

        import diffusers

        diffusers.utils.import_utils._safetensors_available = False

        with tempfile.TemporaryDirectory() as tmpdirname:
            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            with requests_mock.mock(real_http=True) as m:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=tmpdirname
                )

            cache_requests = [r.method for r in m.request_history]
            # Compare the whole list: an empty history then fails with this message instead
            # of raising IndexError on `cache_requests[0]`.
            assert cache_requests == [
                "HEAD"
            ], "We should call only `model_info` to check for _commit hash and `send_telemetry`"

        diffusers.utils.import_utils._safetensors_available = True

103

104
class ModelTesterMixin:
105
    def test_from_save_pretrained(self):
        """Save the model, reload it from disk, and check both give the same forward pass."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        original = self.model_class(**init_dict)
        if hasattr(original, "set_attn_processor"):
            original.set_attn_processor(AttnProcessor())
        original.to(torch_device)
        original.eval()

        with tempfile.TemporaryDirectory() as save_dir:
            original.save_pretrained(save_dir)
            reloaded = self.model_class.from_pretrained(save_dir)
            if hasattr(reloaded, "set_attn_processor"):
                reloaded.set_attn_processor(AttnProcessor())
            reloaded.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(original, ModelMixin):
                _ = original(**self.dummy_input)
                _ = reloaded(**self.dummy_input)

            # Dict-like outputs carry the tensor under `.sample`; anything else is used as-is.
            image = original(**inputs_dict)
            image = image.sample if isinstance(image, dict) else image

            new_image = reloaded(**inputs_dict)
            new_image = new_image.sample if isinstance(new_image, dict) else new_image

        # The compared quantity is the summed absolute difference over all elements.
        summed_abs_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(summed_abs_diff, 5e-5, "Models give different forward passes")
138
139
140
141
142

    def test_from_save_pretrained_variant(self):
        """Save with `variant="fp16"`, reload that variant, and check the plain load fails with `OSError`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        original = self.model_class(**init_dict)
        if hasattr(original, "set_attn_processor"):
            original.set_attn_processor(AttnProcessor())
        original.to(torch_device)
        original.eval()

        with tempfile.TemporaryDirectory() as save_dir:
            original.save_pretrained(save_dir, variant="fp16")
            reloaded = self.model_class.from_pretrained(save_dir, variant="fp16")
            if hasattr(reloaded, "set_attn_processor"):
                reloaded.set_attn_processor(AttnProcessor())

            # non-variant cannot be loaded
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(save_dir)

            # the error message must say that the default weights file is missing
            load_error = str(error_context.exception)
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in load_error

            reloaded.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(original, ModelMixin):
                _ = original(**self.dummy_input)
                _ = reloaded(**self.dummy_input)

            # Dict-like outputs carry the tensor under `.sample`; anything else is used as-is.
            image = original(**inputs_dict)
            image = image.sample if isinstance(image, dict) else image

            new_image = reloaded(**inputs_dict)
            new_image = new_image.sample if isinstance(new_image, dict) else new_image

        # The compared quantity is the summed absolute difference over all elements.
        summed_abs_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(summed_abs_diff, 5e-5, "Models give different forward passes")
180

181
182
183
184
185
186
187
188
189
190
191
192
193
    def test_from_save_pretrained_dtype(self):
        """Check that `from_pretrained(..., torch_dtype=dtype)` returns a model of that dtype."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            # bfloat16 is not exercised on mps
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)

                # Exercise both loading paths, low-memory first.
                for low_cpu_mem_usage in (True, False):
                    new_model = self.model_class.from_pretrained(
                        tmpdirname, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=dtype
                    )
                    assert new_model.dtype == dtype

199
200
201
202
203
    def test_determinism(self):
        """Run the same inputs through the model twice and require near-identical outputs."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            runs = []
            for _ in range(2):
                result = model(**inputs_dict)
                if isinstance(result, dict):
                    result = result.sample
                runs.append(result)

        first, second = runs

        # NaN entries are dropped from each output independently before comparing.
        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]

        largest_gap = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(largest_gap, 1e-5)

    def test_output(self):
        """The forward pass must return a non-None output with the same shape as `inputs_dict["sample"]`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            result = model(**inputs_dict)

        # Dict-like outputs carry the tensor under `.sample`.
        output = result.sample if isinstance(result, dict) else result

        self.assertIsNotNone(output)
        self.assertEqual(output.shape, inputs_dict["sample"].shape, "Input and output shapes do not match")

241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
    def test_forward_with_norm_groups(self):
        """Forward pass with `norm_num_groups=16` and `block_out_channels=(16, 32)` keeps the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Override the normalisation grouping and channel widths from the common config.
        init_dict.update(norm_num_groups=16, block_out_channels=(16, 32))

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            result = model(**inputs_dict)

        # Dict-like outputs carry the tensor under `.sample`.
        output = result.sample if isinstance(result, dict) else result

        self.assertIsNotNone(output)
        self.assertEqual(output.shape, inputs_dict["sample"].shape, "Input and output shapes do not match")

261
262
263
264
265
266
267
268
269
270
271
    def test_forward_signature(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

272
    def test_model_from_pretrained(self):
        """Reload a saved model and check that parameter shapes and output shapes match the original."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        # Build each state dict once; calling `state_dict()` inside the loop rebuilt the
        # full mapping twice per parameter.
        state = model.state_dict()
        new_state = new_model.state_dict()
        for param_name, param_1 in state.items():
            param_2 = new_state[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            # Dict-like outputs carry the tensor under `.sample`.
            if isinstance(output_1, dict):
                output_1 = output_1.sample

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2.sample

        self.assertEqual(output_1.shape, output_2.shape)

306
    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_training(self):
        """Run one forward/backward pass in train mode against random targets."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()

        prediction = model(**inputs_dict)
        # Dict-like outputs carry the tensor under `.sample`.
        if isinstance(prediction, dict):
            prediction = prediction.sample

        # Random target: batch size from the input sample, remaining dims from `self.output_shape`.
        batch_size = inputs_dict["sample"].shape[0]
        target = torch.randn((batch_size,) + self.output_shape).to(torch_device)

        loss = torch.nn.functional.mse_loss(prediction, target)
        loss.backward()

322
    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_ema_training(self):
        """Run one forward/backward pass in train mode, then step an `EMAModel` over the parameters."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        # The EMA tracker is created from the parameters before any gradient is computed.
        ema_model = EMAModel(model.parameters())

        prediction = model(**inputs_dict)
        # Dict-like outputs carry the tensor under `.sample`.
        if isinstance(prediction, dict):
            prediction = prediction.sample

        # Random target: batch size from the input sample, remaining dims from `self.output_shape`.
        batch_size = inputs_dict["sample"].shape[0]
        target = torch.randn((batch_size,) + self.output_shape).to(torch_device)

        loss = torch.nn.functional.mse_loss(prediction, target)
        loss.backward()
        ema_model.step(model.parameters())
340

341
    def test_outputs_equivalence(self):
        """Outputs returned with `return_dict=False` must match the default (dict-style) outputs."""

        def set_nan_tensor_to_zero(t):
            """Return a copy of `t` with NaN entries replaced by zero; `t` itself is left untouched."""
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                # Moving to another device already yields a separate tensor.
                t = t.to("cpu")
            else:
                # Work on a copy so the caller's tensor keeps its NaNs. The failure message
                # below reports them, and in-place zeroing made that report always False.
                t = t.clone()
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            """Walk both output structures in lockstep and compare the tensors at the leaves."""
            if isinstance(tuple_object, (list, tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                    ),
                )

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            outputs_dict = model(**inputs_dict)
            outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)
388

Anton Lozhkov's avatar
Anton Lozhkov committed
389
    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_enable_disable_gradient_checkpointing(self):
        """Gradient checkpointing starts disabled and follows the enable/disable calls."""
        if not self.model_class._supports_gradient_checkpointing:
            # Report a skip instead of returning, which would be recorded as a pass.
            self.skipTest("Model does not support gradient checkpointing")

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # at init model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        # check enable works
        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        # check disable works
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426

    def test_deprecated_kwargs(self):
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )