"tests/models/albert/test_modeling_albert.py" did not exist on "5deed37f9f1a0f5794a2a7cd02164ff265c59524"
test_modeling_common.py 209 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import gc
import inspect
import os
import os.path
import random
import re
import tempfile
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
from packaging import version
from parameterized import parameterized
from pytest import mark

import transformers
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    is_torch_available,
    logging,
    set_seed,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from transformers.testing_utils import (
    CaptureLogger,
    is_flaky,
    is_pt_flax_cross_test,
    is_pt_tf_cross_test,
    require_accelerate,
    require_bitsandbytes,
    require_flash_attn,
    require_read_token,
    require_safetensors,
    require_torch,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torch_sdpa,
    slow,
    torch_device,
)
from transformers.utils import (
    CONFIG_NAME,
    GENERATION_CONFIG_NAME,
    SAFE_WEIGHTS_NAME,
    is_accelerate_available,
    is_flax_available,
    is_tf_available,
    is_torch_bf16_available_on_device,
    is_torch_fp16_available_on_device,
    is_torch_fx_available,
    is_torch_sdpa_available,
)
from transformers.utils.generic import ContextManagers, ModelOutput

if is_accelerate_available():
    from accelerate.utils import compute_module_sizes


if is_torch_available():
    import torch
    import torch.nn.functional as F
    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import save_file as safe_save_file
    from torch import nn

    from transformers import MODEL_MAPPING, AdaptiveEmbedding
    from transformers.modeling_utils import load_state_dict, no_init_weights
    from transformers.pytorch_utils import id_tensor_storage


if is_tf_available():
    import tensorflow as tf

if is_flax_available():
    import jax.numpy as jnp

    from tests.utils.test_modeling_flax_utils import check_models_equal
    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )

if is_torch_fx_available():
    from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace
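

# `_config_zero_init` below returns a deep copy of `config` with every `*_range`/`*_std` value,
# `initializer_factor` and `layer_scale` entry shrunk to 1e-10, so freshly initialized weights
# end up ~0 and modules that bypass the model's `_init_weights` stand out.
# A minimal usage sketch (`model_class` stands for any entry of `all_model_classes`):
#
#     configs_no_init = _config_zero_init(config)
#     model = model_class(config=configs_no_init)
#     # `test_initialization` then checks that every parameter mean rounds to 0.0 or 1.0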
def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
            setattr(configs_no_init, key, 1e-10)
        if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
            no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
            setattr(configs_no_init, key, no_init_subconfig)
    return configs_no_init


def _mock_init_weights(self, module):
    for name, param in module.named_parameters(recurse=False):
        # Use the first letter of the parameter name to get a deterministic fill value: ord(name[0].lower()) - 110 maps "a" -> -13, ..., "z" -> 12
        value = ord(name[0].lower()) - 110
        param.data.fill_(value)


def _mock_all_init_weights(self):
    # Prune heads if needed
    if self.config.pruned_heads:
        self.prune_heads(self.config.pruned_heads)

    import transformers.modeling_utils

    if transformers.modeling_utils._init_weights:
        for module in self.modules():
            module._is_hf_initialized = False
        # Initialize weights
        self.apply(self._initialize_weights)

        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways
        self.tie_weights()
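

# NOTE: `_mock_init_weights` and `_mock_all_init_weights` are monkey-patched onto copies of the
# model classes under test (see `test_save_load_fast_init_from_base`,
# `test_save_load_fast_init_to_base` and `test_torch_save_load` below) so that weight
# initialization becomes deterministic and the fast-init and slow-init loading paths can be
# compared value by value.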


@require_torch
class ModelTesterMixin:
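    """
    Model-agnostic checks shared by the per-model PyTorch test classes. A per-model test class is
    expected to mix this in together with `unittest.TestCase`, point `model_tester` at a helper
    exposing `prepare_config_and_inputs_for_common()`, and list the classes under test in
    `all_model_classes`.
    """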
    model_tester = None
    all_model_classes = ()
    all_generative_model_classes = ()
    fx_compatible = False
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
    test_resize_position_embeddings = False
    test_head_masking = True
    test_mismatched_shapes = True
    test_missing_keys = True
    test_model_parallel = False
    is_encoder_decoder = False
    has_attentions = True
    model_split_percents = [0.5, 0.7, 0.9]
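    # Per-model test classes override the class attributes above to opt out of individual checks,
    # e.g. a tester that sets `test_torchscript = False` makes `_create_and_check_torchscript` skip.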

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)
        if model_class.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
            inputs_dict = {
                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                if isinstance(v, torch.Tensor) and v.ndim > 1
                else v
                for k, v in inputs_dict.items()
            }
        elif model_class.__name__ in get_values(MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES):
            inputs_dict.pop("attention_mask")

        if return_labels:
            if model_class.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
            elif model_class.__name__ in [
                *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
            ]:
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class.__name__ in [
                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class.__name__ in [
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
            elif model_class.__name__ in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES):
                num_patches = self.model_tester.image_size // self.model_tester.patch_size
                inputs_dict["bool_masked_pos"] = torch.zeros(
                    (self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
                )
            elif model_class.__name__ in get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES):
                batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
                inputs_dict["labels"] = torch.zeros(
                    [self.model_tester.batch_size, height, width], device=torch_device
                ).long()

        return inputs_dict

    def test_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_save_load(out1, out2):
            # make sure we don't have nans
            out_2 = out2.cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            out_1 = out1.cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # the config file (and the generation config file, if it can generate) should be saved
                self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
                self.assertEqual(
                    model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
                )

                model = model_class.from_pretrained(tmpdirname)
                model.to(torch_device)
                with torch.no_grad():
                    second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_save_load(tensor1, tensor2)
            else:
                check_save_load(first, second)

    def test_from_pretrained_no_checkpoint(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            state_dict = model.state_dict()

            new_model = model_class.from_pretrained(
                pretrained_model_name_or_path=None, config=config, state_dict=state_dict
            )
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.equal(p1, p2))

    def test_keep_in_fp32_modules(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            if model_class._keep_in_fp32_modules is None:
                self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")

            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)

                for name, param in model.named_parameters():
                    if any(n in model_class._keep_in_fp32_modules for n in name.split(".")):
                        self.assertTrue(param.dtype == torch.float32)
                    else:
                        self.assertTrue(param.dtype == torch.float16, name)

    def test_save_load_keys_to_ignore_on_save(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
            if _keys_to_ignore_on_save is None:
                continue

            # check the keys are in the original state_dict
            for k in _keys_to_ignore_on_save:
                self.assertIn(k, model.state_dict().keys(), "\n".join(model.state_dict().keys()))

            # check that certain keys didn't get saved with the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                output_model_file = os.path.join(tmpdirname, SAFE_WEIGHTS_NAME)
                state_dict_saved = safe_load_file(output_model_file)

                for k in _keys_to_ignore_on_save:
                    self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))

                # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
                load_result = model.load_state_dict(state_dict_saved, strict=False)
                keys_to_ignore = set(model._keys_to_ignore_on_save)

                if hasattr(model, "_tied_weights_keys"):
                    keys_to_ignore.update(set(model._tied_weights_keys))

                self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
                self.assertTrue(len(load_result.unexpected_keys) == 0)

    def test_gradient_checkpointing_backward_compatibility(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class.supports_gradient_checkpointing:
                continue

            config.gradient_checkpointing = True
            model = model_class(config)
            self.assertTrue(model.is_gradient_checkpointing)

    def test_gradient_checkpointing_enable_disable(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class.supports_gradient_checkpointing:
                continue

            # at init model should have gradient checkpointing disabled
            model = model_class(config)
            self.assertFalse(model.is_gradient_checkpointing)

            # check enable works
            model.gradient_checkpointing_enable()
            self.assertTrue(model.is_gradient_checkpointing)

            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
            for n, m in model.named_modules():
                if hasattr(m, "gradient_checkpointing"):
                    self.assertTrue(
                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
                    )

            # check disable works
            model.gradient_checkpointing_disable()
            self.assertFalse(model.is_gradient_checkpointing)

            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
            for n, m in model.named_modules():
                if hasattr(m, "gradient_checkpointing"):
                    self.assertFalse(
                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
                    )

    @is_flaky(description="low likelihood of failure, reason not yet discovered")
    def test_save_load_fast_init_from_base(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if config.__class__ not in MODEL_MAPPING:
            self.skipTest(reason="Model class not in MODEL_MAPPING")

        base_class = MODEL_MAPPING[config.__class__]

        if isinstance(base_class, tuple):
            base_class = base_class[0]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            # make a copy of model class to not break future tests
            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
            class CopyClass(model_class):
                pass

            model_class_copy = CopyClass

            # make sure that all keys are expected for test
            model_class_copy._keys_to_ignore_on_load_missing = []

            # make init deterministic, but make sure that
            # non-initialized weights throw errors nevertheless
            model_class_copy._init_weights = _mock_init_weights
            model_class_copy.init_weights = _mock_all_init_weights

            model = base_class(config)
            state_dict = model.state_dict()

            # this will often delete a single weight of a multi-weight module
            # to test an edge case
            random_key_to_del = random.choice(list(state_dict.keys()))
            del state_dict[random_key_to_del]

            # check that certain keys didn't get saved with the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

                model_fast_init = model_class_copy.from_pretrained(tmpdirname)
                model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
                # Before we test anything

                for key in model_fast_init.state_dict().keys():
                    if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
                        max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
                    else:
                        max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    @slow
    @require_accelerate
    @mark.accelerate_tests
    def test_save_load_low_cpu_mem_usage(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        with tempfile.TemporaryDirectory() as saved_model_path:
            for model_class in self.all_model_classes:
                model_to_save = model_class(config)
                model_to_save.save_pretrained(saved_model_path)

                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)

    @slow
    @require_accelerate
    @mark.accelerate_tests
    def test_save_load_low_cpu_mem_usage_checkpoints(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        with tempfile.TemporaryDirectory() as saved_model_path:
            for model_class in self.all_model_classes:
                model_to_save = model_class(config)
                model_to_save.config.save_pretrained(saved_model_path)
                torch.save(model_to_save.state_dict(), os.path.join(saved_model_path, "pytorch_model.bin"))

                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)

    @slow
    @require_accelerate
    @mark.accelerate_tests
    def test_save_load_low_cpu_mem_usage_no_safetensors(self):
        with tempfile.TemporaryDirectory() as saved_model_path:
            for model_class in self.all_model_classes:
                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
                model_to_save = model_class(config)

                model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)

    def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
        from accelerate.utils.modeling import named_module_tensors

        # Load the low usage and the normal models.
        model_low_usage, loading_info = model_class.from_pretrained(
            saved_model_path,
            low_cpu_mem_usage=True,
            output_loading_info=True,
        )
        model_non_low_usage = model_class.from_pretrained(saved_model_path)

        # Check that there were no missing keys.
        self.assertEqual(loading_info["missing_keys"], [])

        # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
        # subsequently loaded with the correct values and onto the correct device. We check if there are any
        # remaining params that were not properly loaded.
        for name, tensor in named_module_tensors(model_low_usage, recurse=True):
            self.assertNotEqual(
                tensor.device,
                torch.device("meta"),
                "Tensor '" + name + "' has not been properly loaded and has device=meta.",
            )

        # Check that the parameters are equal.
        for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
            self.assertEqual(p1.data.ne(p2.data).sum(), 0)

        # Check that the state dict keys are equal.
        self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))

        # Check that the shared tensors are equal.
        tensor_ptrs1 = collections.defaultdict(list)
        for name, tensor in model_low_usage.state_dict().items():
            tensor_ptrs1[id_tensor_storage(tensor)].append(name)
        tied_params1 = [names for _, names in tensor_ptrs1.items() if len(names) > 1]

        tensor_ptrs2 = collections.defaultdict(list)
        for name, tensor in model_non_low_usage.state_dict().items():
            tensor_ptrs2[id_tensor_storage(tensor)].append(name)
        tied_params2 = [names for _, names in tensor_ptrs2.items() if len(names) > 1]

        self.assertEqual(tied_params1, tied_params2)

    def test_save_load_fast_init_to_base(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if config.__class__ not in MODEL_MAPPING:
            self.skipTest(reason="Model class not in MODEL_MAPPING")

        base_class = MODEL_MAPPING[config.__class__]

        if isinstance(base_class, tuple):
            base_class = base_class[0]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            # make a copy of model class to not break future tests
            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
            class CopyClass(base_class):
                pass

            base_class_copy = CopyClass

            # make sure that all keys are expected for test
            base_class_copy._keys_to_ignore_on_load_missing = []

            # make init deterministic, but make sure that
            # non-initialized weights throw errors nevertheless
            base_class_copy._init_weights = _mock_init_weights
            base_class_copy.init_weights = _mock_all_init_weights

            model = model_class(config)
            state_dict = model.state_dict()

            # this will often delete a single weight of a multi-weight module
            # to test an edge case
            random_key_to_del = random.choice(list(state_dict.keys()))
            del state_dict[random_key_to_del]

            # check that certain keys didn't get saved with the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.config.save_pretrained(tmpdirname)
                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

                model_fast_init = base_class_copy.from_pretrained(tmpdirname)
                model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

                for key in model_fast_init.state_dict().keys():
                    if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
                        max_diff = torch.max(
                            model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
                        ).item()
                    else:
                        max_diff = torch.max(
                            torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
                        ).item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_torch_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if config.__class__ not in MODEL_MAPPING:
            self.skipTest(reason="Model class not in MODEL_MAPPING")

        base_class = MODEL_MAPPING[config.__class__]

        if isinstance(base_class, tuple):
            base_class = base_class[0]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            # make a copy of model class to not break future tests
            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
            class CopyClass(base_class):
                pass

            base_class_copy = CopyClass

            # make sure that all keys are expected for test
            base_class_copy._keys_to_ignore_on_load_missing = []

            # make init deterministic, but make sure that
            # non-initialized weights throw errors nevertheless
            base_class_copy._init_weights = _mock_init_weights
            base_class_copy.init_weights = _mock_all_init_weights

            model = model_class(config)
            state_dict = model.state_dict()

            def check_equal(loaded):
                for key in state_dict.keys():
                    max_diff = torch.max(
                        state_dict[key] ^ loaded[key]
                        if isinstance(state_dict[key], torch.BoolTensor)
                        else torch.abs(state_dict[key] - loaded[key])
                    ).item()
                    self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")

            # check that certain keys didn't get saved with the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
                check_equal(load_state_dict(pt_checkpoint_path))
                torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
                check_equal(load_state_dict(pt_checkpoint_path))

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def test_determinism(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_determinism(first, second):
            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_determinism(tensor1, tensor2)
            else:
                check_determinism(first, second)

    def test_batching_equivalence(self):
        """
        Tests that the model supports batching and that the output is nearly the same for the same input in
        different batch sizes.
        (Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to
        different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
        """

        def get_tensor_equivalence_function(batched_input):
            # models operating on continuous spaces have higher abs difference than LMs
            # instead, we can rely on cos distance for image/speech models, similar to `diffusers`
            if "input_ids" not in batched_input:
                return lambda tensor1, tensor2: (
                    1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
                )
            return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))

        def recursive_check(batched_object, single_row_object, model_name, key):
            if isinstance(batched_object, (list, tuple)):
                for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            elif isinstance(batched_object, dict):
                for batched_object_value, single_row_object_value in zip(
                    batched_object.values(), single_row_object.values()
                ):
                    recursive_check(batched_object_value, single_row_object_value, model_name, key)
            # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
            elif batched_object is None or not isinstance(batched_object, torch.Tensor):
                return
            elif batched_object.dim() == 0:
                return
            else:
                # indexing the first element does not always work
                # e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
                slice_ids = [slice(0, index) for index in single_row_object.shape]
                batched_row = batched_object[slice_ids]
                self.assertFalse(
                    torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
                )
                self.assertFalse(
                    torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
                )
                self.assertTrue(
                    (equivalence(batched_row, single_row_object)) <= 1e-03,
                    msg=(
                        f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                        f"Difference={equivalence(batched_row, single_row_object)}."
                    ),
                )

        config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
        equivalence = get_tensor_equivalence_function(batched_input)

        for model_class in self.all_model_classes:
            config.output_hidden_states = True

            model_name = model_class.__name__
            if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
                config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
            batched_input_prepared = self._prepare_for_class(batched_input, model_class)
            model = model_class(config).to(torch_device).eval()

            batch_size = self.model_tester.batch_size
            single_row_input = {}
            for key, value in batched_input_prepared.items():
                if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
                    # e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
                    single_batch_shape = value.shape[0] // batch_size
                    single_row_input[key] = value[:single_batch_shape]
                else:
                    single_row_input[key] = value

            with torch.no_grad():
                model_batched_output = model(**batched_input_prepared)
                model_row_output = model(**single_row_input)

            if isinstance(model_batched_output, torch.Tensor):
                model_batched_output = {"model_output": model_batched_output}
                model_row_output = {"model_output": model_row_output}

            for key in model_batched_output:
                # DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan`
                if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key:
                    model_batched_output[key] = model_batched_output[key][1:]
                    model_row_output[key] = model_row_output[key][1:]
                recursive_check(model_batched_output[key], model_row_output[key], model_name, key)

    def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
        if not self.model_tester.is_training:
            self.skipTest(reason="ModelTester is not configured to run training tests")

        for model_class in self.all_model_classes:
            if (
                model_class.__name__
                in [
                    *get_values(MODEL_MAPPING_NAMES),
                    *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                ]
                or not model_class.supports_gradient_checkpointing
            ):
                continue

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.use_cache = False
            config.return_dict = True
            model = model_class(config)

            model.to(torch_device)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
            model.train()

            # unfreeze additional layers
            for p in model.parameters():
                p.requires_grad_(True)

            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()
            optimizer.step()

            for k, v in model.named_parameters():
                if v.requires_grad:
                    self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")

    def test_training(self):
        if not self.model_tester.is_training:
            self.skipTest(reason="ModelTester is not configured to run training tests")

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.return_dict = True

            if model_class.__name__ in [
                *get_values(MODEL_MAPPING_NAMES),
                *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
            ]:
                continue

            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        # Scenario - 1 default behaviour
        self.check_training_gradient_checkpointing()

    def test_training_gradient_checkpointing_use_reentrant(self):
        # Scenario - 2 with `use_reentrant=True` - this is the default value that is used in pytorch's
        # torch.utils.checkpoint.checkpoint
        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": True})

    def test_training_gradient_checkpointing_use_reentrant_false(self):
        # Scenario - 3 with `use_reentrant=False` pytorch suggests users to use this value for
        # future releases: https://pytorch.org/docs/stable/checkpoint.html
        self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": False})

    def test_attention_outputs(self):
        if not self.has_attentions:
            self.skipTest(reason="Model does not output attentions")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # loss is at first position
                if "labels" in inputs_dict:
                    correct_outlen += 1  # loss is added to beginning
                # Question Answering model returns start_logits and end_logits
                if model_class.__name__ in [
                    *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
                    *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
                ]:
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output
                if "past_key_values" in outputs:
                    correct_outlen += 1  # past_key_values have been returned

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            if chunk_length is not None:
                self.assertListEqual(
                    list(self_attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )

    @slow
    def test_torchscript_simple(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torchscript(config, inputs_dict)

    @slow
    def test_torchscript_output_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_attentions = True
        self._create_and_check_torchscript(config, inputs_dict)

    @slow
    def test_torchscript_output_hidden_state(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        self._create_and_check_torchscript(config, inputs_dict)

    # This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry`
    def clear_torch_jit_class_registry(self):
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        # torch 1.8 has no `_clear_class_state` in `torch.jit._state`
        if hasattr(torch.jit._state, "_clear_class_state"):
            torch.jit._state._clear_class_state()

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            for attn_implementation in ["eager", "sdpa"]:
                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                    continue

                configs_no_init._attn_implementation = attn_implementation
                model = model_class(config=configs_no_init)
                model.to(torch_device)
                model.eval()
                inputs = self._prepare_for_class(inputs_dict, model_class)

                main_input_name = model_class.main_input_name

                try:
                    if model.config.is_encoder_decoder:
                        model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
                        main_input = inputs[main_input_name]
                        attention_mask = inputs["attention_mask"]
                        decoder_input_ids = inputs["decoder_input_ids"]
                        decoder_attention_mask = inputs["decoder_attention_mask"]
                        model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
                        traced_model = torch.jit.trace(
                            model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
                        )
                    elif "bbox" in inputs and "image" in inputs:  # LayoutLMv2 requires additional inputs
                        input_ids = inputs["input_ids"]
                        bbox = inputs["bbox"]
                        image = inputs["image"].tensor
                        model(input_ids, bbox, image)
                        traced_model = torch.jit.trace(
                            model, (input_ids, bbox, image), check_trace=False
                        )  # when traced model is checked, an error is produced due to name mangling
                    elif "bbox" in inputs:  # Bros requires additional inputs (bbox)
                        input_ids = inputs["input_ids"]
                        bbox = inputs["bbox"]
                        model(input_ids, bbox)
                        traced_model = torch.jit.trace(
                            model, (input_ids, bbox), check_trace=False
                        )  # when traced model is checked, an error is produced due to name mangling
                    elif (
                        "pixel_values" in inputs and "prompt_pixel_values" in inputs and "prompt_masks" in inputs
                    ):  # SegGpt requires additional inputs
                        pixel_values = inputs["pixel_values"]
                        prompt_pixel_values = inputs["prompt_pixel_values"]
                        prompt_masks = inputs["prompt_masks"]
                        model(pixel_values, prompt_pixel_values, prompt_masks)
                        traced_model = torch.jit.trace(
                            model, (pixel_values, prompt_pixel_values, prompt_masks), check_trace=False
                        )  # when traced model is checked, an error is produced due to name mangling
                    else:
                        main_input = inputs[main_input_name]

                        if model.config._attn_implementation == "sdpa":
                            trace_input = {main_input_name: main_input}

                            if "attention_mask" in inputs:
                                trace_input["attention_mask"] = inputs["attention_mask"]
                            else:
                                self.skipTest(reason="testing SDPA without attention_mask is not supported")

                            model(main_input, attention_mask=inputs["attention_mask"])
                            # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
                            traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
                        else:
                            model(main_input)
                            traced_model = torch.jit.trace(model, (main_input,))
                except RuntimeError:
                    self.fail("Couldn't trace module.")

                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                    try:
                        torch.jit.save(traced_model, pt_file_name)
                    except Exception:
                        self.fail("Couldn't save module.")

                    try:
                        loaded_model = torch.jit.load(pt_file_name)
                    except Exception:
                        self.fail("Couldn't load module.")

                model.to(torch_device)
                model.eval()

                loaded_model.to(torch_device)
                loaded_model.eval()

                model_state_dict = model.state_dict()
                loaded_model_state_dict = loaded_model.state_dict()
                non_persistent_buffers = {}
                for key in loaded_model_state_dict.keys():
                    if key not in model_state_dict.keys():
                        non_persistent_buffers[key] = loaded_model_state_dict[key]
                loaded_model_state_dict = {
                    key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
                }
                self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
                model_buffers = list(model.buffers())
                for non_persistent_buffer in non_persistent_buffers.values():
                    found_buffer = False
                    for i, model_buffer in enumerate(model_buffers):
                        if torch.equal(non_persistent_buffer, model_buffer):
                            found_buffer = True
                            break
                    self.assertTrue(found_buffer)
                    model_buffers.pop(i)
                models_equal = True
                for layer_name, p1 in model_state_dict.items():
                    if layer_name in loaded_model_state_dict:
                        p2 = loaded_model_state_dict[layer_name]
                        if p1.data.ne(p2.data).sum() > 0:
                            models_equal = False
                self.assertTrue(models_equal)
                # Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
                # (Even with this call, there is still a memory leak of ~0.04MB.)
                self.clear_torch_jit_class_registry()

    def test_torch_fx(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict)

    def test_torch_fx_output_loss(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torch_fx_tracing(config, inputs_dict, output_loss=True)

    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
        if not is_torch_fx_available() or not self.fx_compatible:
            self.skipTest(
                f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx"
            )

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.return_dict = False

        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

            # We may want to test several inputs (various shapes, etc.).
            inputs_to_test = [inputs]
            if model.config.is_encoder_decoder:
                model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
                labels = inputs.get("labels", None)
                input_names = [
                    "attention_mask",
                    "decoder_attention_mask",
                    "decoder_input_ids",
                    "input_features",
                    "input_ids",
                    "input_values",
                ]
                if labels is not None:
                    input_names.append("labels")
            else:
                input_names = [
                    "attention_mask",
                    "bbox",
                    "input_features",
                    "input_ids",
                    "input_values",
                    "inputs_embeds",
                    "pixel_values",
                    "token_type_ids",
                    "visual_feats",
                    "visual_pos",
                ]
                labels = inputs.get("labels", None)
                start_positions = inputs.get("start_positions", None)
                end_positions = inputs.get("end_positions", None)
                if labels is not None:
                    input_names.append("labels")
                if start_positions is not None:
                    input_names.append("start_positions")
                if end_positions is not None:
                    input_names.append("end_positions")

                if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
                    input_names.append("past_key_values")

                    # Generally, model_tester.prepare_config_and_inputs_for_common does not seem to generate past key value inputs.
                    if "past_key_values" not in inputs:
                        batch_size = inputs[next(iter(inputs))].shape[0]
                        num_heads = model.config.num_attention_heads
                        head_dim = model.config.hidden_size // model.config.num_attention_heads

                        cache_shape = (batch_size, num_heads, 0, head_dim)
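                        # A sequence length of 0 makes this an "empty" past: one (key, value) pair per layer
                        # with no cached positions.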
                        empty_pkv = tuple(
                            (
                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                            )
                            for i in range(model.config.num_hidden_layers)
                        )
                        cache_length = 9
                        cache_shape = (batch_size, num_heads, cache_length, head_dim)
                        non_empty_pkv = tuple(
                            (
                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                                torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                            )
                            for i in range(model.config.num_hidden_layers)
                        )
                        inps = copy.deepcopy(inputs_to_test[0])
                        inputs_to_test[0]["past_key_values"] = empty_pkv
                        inps["past_key_values"] = non_empty_pkv
                        inputs_to_test.append(inps)
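                        # A non-empty cache must be matched by an attention mask that also covers the cached
                        # positions, so prepend `cache_length` ones to the existing mask.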
                        past_mask = torch.ones(batch_size, cache_length, device=torch_device, dtype=torch.float)
                        inputs_to_test[1]["attention_mask"] = torch.cat(
                            (past_mask, inputs_to_test[1]["attention_mask"]), dim=1
                        )
            if "inputs_embeds" in inspect.signature(model.forward).parameters and not model.config.is_encoder_decoder:
                inputs_to_test.append(
                    {
                        "inputs_embeds": torch.rand(
                            2, 2, model.config.hidden_size, dtype=torch.float, device=torch_device
                        )
                    }
                )

            for inps in inputs_to_test:
                filtered_inputs = {k: v for (k, v) in inps.items() if k in input_names}
                input_names_to_trace = list(filtered_inputs.keys())
                if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
                    not hasattr(model.config, "problem_type") or model.config.problem_type is None
                ):
                    model.config.problem_type = "single_label_classification"
                model.config.use_cache = "past_key_values" in input_names_to_trace

                traced_model = symbolic_trace(model, input_names_to_trace)
                with torch.no_grad():
                    traced_output = traced_model(**filtered_inputs)
                    model_output = model(**filtered_inputs)
                def flatten_output(output):
                    flatten = []
                    for x in output:
                        if isinstance(x, (tuple, list)):
                            flatten += flatten_output(x)
                        elif not isinstance(x, torch.Tensor):
                            continue
                        else:
                            flatten.append(x)
                    return flatten
                model_output = flatten_output(model_output)
                traced_output = flatten_output(traced_output)
                num_outputs = len(model_output)

                for i in range(num_outputs):
                    self.assertTrue(
                        torch.allclose(model_output[i], traced_output[i]),
                        f"traced {i}th output doesn't match model {i}th output for {model_class}",
                    )

                # Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
                # (Even with this call, there is still a memory leak of ~0.04MB.)
                self.clear_torch_jit_class_registry()

    def test_headmasking(self):
        if not self.test_head_masking:
            self.skipTest(reason="Model does not support head masking")

        global_rng.seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        global_rng.seed()
        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            # Prepare head_mask
            # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
            head_mask = torch.ones(
                self.model_tester.num_hidden_layers,
                self.model_tester.num_attention_heads,
                device=torch_device,
            )
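            # A value of 0.0 masks the corresponding head: mask the first head of the first layer and all
            # but the last head of the last layer.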
            head_mask[0, 0] = 0
            head_mask[-1, :-1] = 0
            head_mask.requires_grad_(requires_grad=True)
            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask
            if model.config.is_encoder_decoder:
                signature = inspect.signature(model.forward)
                arg_names = [*signature.parameters.keys()]
                if "decoder_head_mask" in arg_names:  # necessary differentiation because of the T5 model
                    inputs["decoder_head_mask"] = head_mask
                if "cross_attn_head_mask" in arg_names:
                    inputs["cross_attn_head_mask"] = head_mask
            outputs = model(**inputs, return_dict=True)

            # Test that we can get a gradient back for importance score computation
            output = sum(t.sum() for t in outputs[0])
            output = output.sum()
            output.backward()
            multihead_outputs = head_mask.grad

            self.assertIsNotNone(multihead_outputs)
            self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)

            def check_attentions_validity(attentions):
                # Remove Nan
                for t in attentions:
                    self.assertLess(
                        torch.sum(torch.isnan(t)), t.numel() / 4
                    )  # Check we don't have more than 25% nans (arbitrary)
                attentions = [
                    t.masked_fill(torch.isnan(t), 0.0) for t in attentions
                ]  # remove them (the test is less complete)
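                # Masked heads should have (close to) zero attention weights, unmasked heads non-zero ones.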

                self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
                if len(attentions) > 2:  # encoder-decoder models have only 2 layers in each module
                    self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
                self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)

            if model.config.is_encoder_decoder:
                check_attentions_validity(outputs.encoder_attentions)
                check_attentions_validity(outputs.decoder_attentions)
                check_attentions_validity(outputs.cross_attentions)
            else:
                check_attentions_validity(outputs.attentions)

    def test_head_pruning(self):
        if not self.test_pruning:
            self.skipTest(reason="Pruning is not activated")

        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()
            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
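            # {layer index: heads to prune}: keep only the first head in the first layer and prune the
            # first head in the last layer (index -1).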
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], 1)
            # TODO: To have this check, we will need at least 3 layers. Do we really need it?
            # self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_pretrained(self):
        if not self.test_pruning:
            self.skipTest(reason="Pruning is not activated")
        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)
            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], 1)
            # TODO: To have this check, we will need at least 3 layers. Do we really need it?
            # self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_config_init(self):
        if not self.test_pruning:
            self.skipTest(reason="Pruning is not activated")
        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()
            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            config.pruned_heads = heads_to_prune
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], 1)
            # TODO: To have this check, we will need at least 3 layers. Do we really need it?
            # self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_integration(self):
        if not self.test_pruning:
            self.skipTest(reason="Pruning is not activated")
        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()
            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            heads_to_prune = {1: [1, 2]}
            config.pruned_heads = heads_to_prune
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 0)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 0)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            heads_to_prune = {0: [0], 1: [1, 2]}
            model.prune_heads(heads_to_prune)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2]})

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)
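            # The default expectation is num_hidden_layers + 1: the embedding output plus one hidden state per layer.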
            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
                if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
                    seq_length = seq_length * self.model_tester.chunk_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        output = outputs[0]

        if config.is_encoder_decoder:
            # Seq2Seq models
            encoder_hidden_states = outputs.encoder_hidden_states[0]
            encoder_hidden_states.retain_grad()

            decoder_hidden_states = outputs.decoder_hidden_states[0]
            decoder_hidden_states.retain_grad()

            if self.has_attentions:
                encoder_attentions = outputs.encoder_attentions[0]
                encoder_attentions.retain_grad()

                decoder_attentions = outputs.decoder_attentions[0]
                decoder_attentions.retain_grad()

                cross_attentions = outputs.cross_attentions[0]
                cross_attentions.retain_grad()

            output.flatten()[0].backward(retain_graph=True)

            self.assertIsNotNone(encoder_hidden_states.grad)
            self.assertIsNotNone(decoder_hidden_states.grad)

            if self.has_attentions:
                self.assertIsNotNone(encoder_attentions.grad)
                self.assertIsNotNone(decoder_attentions.grad)
                self.assertIsNotNone(cross_attentions.grad)
        else:
            # Encoder-/Decoder-only models
            hidden_states = outputs.hidden_states[0]
            hidden_states.retain_grad()

            if self.has_attentions:
                attentions = outputs.attentions[0]
                attentions.retain_grad()

            output.flatten()[0].backward(retain_graph=True)

            self.assertIsNotNone(hidden_states.grad)

            if self.has_attentions:
                self.assertIsNotNone(attentions.grad)

    def test_feed_forward_chunking(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            torch.manual_seed(0)
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            torch.manual_seed(0)
            config.chunk_size_feed_forward = 1
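            # Chunking the feed-forward pass (chunk size 1) must not change the outputs; it only trades
            # memory for compute.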
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
            self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))

    def test_resize_position_vector_embeddings(self):
        if not self.test_resize_position_embeddings:
            self.skipTest(reason="Model does not have position embeddings")

        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            max_position_embeddings = config.max_position_embeddings

            # Retrieve the embeddings and clone them
            if model.config.is_encoder_decoder:
                encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
                encoder_cloned_embeddings = encoder_model_embed.weight.clone()
                decoder_cloned_embeddings = decoder_model_embed.weight.clone()
            else:
                model_embed = model.get_position_embeddings()
                cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the position embeddings with a larger max_position_embeddings increases
            # the model's position embeddings size
            model.resize_position_embeddings(max_position_embeddings + 10)
            self.assertEqual(model.config.max_position_embeddings, max_position_embeddings + 10)

            # Check that it actually resizes the embeddings matrix
            if model.config.is_encoder_decoder:
                encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
                self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] + 10)
                self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] + 10)
            else:
                model_embed = model.get_position_embeddings()
                self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the position embeddings with a smaller max_position_embeddings decreases
            # the model's max_position_embeddings
            model.resize_position_embeddings(max_position_embeddings - 5)
            self.assertEqual(model.config.max_position_embeddings, max_position_embeddings - 5)

            # Check that it actually resizes the embeddings matrix
            if model.config.is_encoder_decoder:
                encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
                self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] - 5)
                self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] - 5)
            else:
                model_embed = model.get_position_embeddings()
                self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 5)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True

            if model.config.is_encoder_decoder:
                for p1, p2 in zip(encoder_cloned_embeddings, encoder_model_embed.weight):
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False
                for p1, p2 in zip(decoder_cloned_embeddings, decoder_model_embed.weight):
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False
            else:
                for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                    if p1.data.ne(p2.data).sum() > 0:
                        models_equal = False

            self.assertTrue(models_equal)

    def test_resize_tokens_embeddings(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            self.skipTest(reason="test_resize_embeddings is set to `False`")

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)

            # make sure that decoder_input_ids are resized as well
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
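            # With pad_to_multiple_of=1 no padding rows are added, so the new vocabulary size is exactly
            # model_vocab_size + 10.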
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)

            model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(model_embed.weight.shape[0] % 64, 0)

            self.assertEqual(model_embed.weight.shape[0], new_model_vocab_size)
            self.assertEqual(new_model_vocab_size, model.vocab_size)

            model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
            self.assertEqual(model_embed.weight.shape[0] % 64, 0)

            # Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
            target_dimension = 128
            model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
            self.assertEqual(model_embed.weight.shape[0], target_dimension)

            with self.assertRaisesRegex(
                ValueError,
                "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
            ):
                model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)

    def test_resize_embeddings_untied(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            self.skipTest(reason="test_resize_embeddings is set to `False`")

        original_config.tie_word_embeddings = False

        # if the model cannot untie embeddings -> leave test
        if original_config.tie_word_embeddings:
            self.skipTest(reason="Model cannot untied embeddings")

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            new_model_vocab_size = (
                model.config.text_config.vocab_size
                if hasattr(model.config, "text_config")
                else model.config.vocab_size
            )
            self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_model_get_set_embeddings(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))

            new_input_embedding_layer = nn.Embedding(10, 10)
            model.set_input_embeddings(new_input_embedding_layer)
            self.assertEqual(model.get_input_embeddings(), new_input_embedding_layer)

            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_model_main_input_name(self):
        for model_class in self.all_model_classes:
            model_signature = inspect.signature(getattr(model_class, "forward"))
            # The main input is the name of the argument after `self`
            observed_main_input_name = list(model_signature.parameters.keys())[1]
            self.assertEqual(model_class.main_input_name, observed_main_input_name)

    def test_correct_missing_keys(self):
        if not self.test_missing_keys:
            self.skipTest(reason="test_missing_keys is set to `False`")
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            base_model_prefix = model.base_model_prefix

            if hasattr(model, base_model_prefix):
                extra_params = {k: v for k, v in model.named_parameters() if not k.startswith(base_model_prefix)}
                extra_params.update({k: v for k, v in model.named_buffers() if not k.startswith(base_model_prefix)})
                # Some models define this as None
                if model._keys_to_ignore_on_load_missing:
                    for key in model._keys_to_ignore_on_load_missing:
                        extra_params.pop(key, None)

                if not extra_params:
                    # In that case, we *are* on a head model, but none of the extra keys are actual
                    # parameters; this case is covered by the `test_tied_model_weights_key_ignore` test.
                    continue

                with tempfile.TemporaryDirectory() as temp_dir_name:
                    model.base_model.save_pretrained(temp_dir_name)
                    model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
                    self.assertGreater(len(loading_info["missing_keys"]), 0, model.__class__.__name__)

    def test_tie_model_weights(self):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_same_values(layer_1, layer_2):
            equal = True
            for p1, p2 in zip(layer_1.weight, layer_2.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    equal = False
            return equal

        for model_class in self.all_model_classes:
            config.torchscript = True
            model_not_tied = model_class(config)
            if model_not_tied.get_output_embeddings() is None:
                continue

            config_tied = copy.deepcopy(config)
            config_tied.torchscript = False
            model_tied = model_class(config_tied)
            params_tied = list(model_tied.parameters())
            # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(check_same_values(embeddings, decoding))

            # Check that after resize they remain tied.
            vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
            model_tied.resize_token_embeddings(vocab_size + 10)
            params_tied_2 = list(model_tied.parameters())
            self.assertEqual(len(params_tied_2), len(params_tied))
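            # model.parameters() de-duplicates shared tensors, so the parameter count is unchanged only if
            # the embeddings are still tied after the resize.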

    @require_safetensors
    def test_can_use_safetensors(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model_tied = model_class(config)
            with tempfile.TemporaryDirectory() as d:
                try:
                    model_tied.save_pretrained(d, safe_serialization=True)
                except Exception as e:
                    raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")

                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
                # Checking the state dicts are correct
                reloaded_state = model_reloaded.state_dict()
                for k, v in model_tied.state_dict().items():
                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
                    torch.testing.assert_close(
                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
                    )
                # Check that there was no complaint about missing weights
                self.assertEqual(infos["missing_keys"], [])

                # Check that the tensor sharing is correct
                ptrs = defaultdict(list)
                for k, v in model_tied.state_dict().items():
                    ptrs[v.data_ptr()].append(k)

                shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
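                # Tensors that shared storage before saving must still share storage after reloading.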

                for _, shared_names in shared_ptrs.items():
                    reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
                    self.assertEqual(
                        len(reloaded_ptrs),
                        1,
                        f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
                    )

    def test_load_save_without_tied_weights(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.tie_word_embeddings = False
        for model_class in self.all_model_classes:
            model = model_class(config)
            with tempfile.TemporaryDirectory() as d:
                model.save_pretrained(d)

                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
                # Checking the state dicts are correct
                reloaded_state = model_reloaded.state_dict()
                for k, v in model.state_dict().items():
                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
                    torch.testing.assert_close(
                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
                    )
                # Check that there was no complaint about missing weights
                self.assertEqual(infos["missing_keys"], [])

    def test_tied_weights_keys(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        config.tie_word_embeddings = True
        for model_class in self.all_model_classes:
            model_tied = model_class(config)

            ptrs = collections.defaultdict(list)
            for name, tensor in model_tied.state_dict().items():
                ptrs[id_tensor_storage(tensor)].append(name)

            # These are all the pointers of shared tensors.
            tied_params = [names for _, names in ptrs.items() if len(names) > 1]

            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
            # Detect we get a hit for each key
            for key in tied_weight_keys:
                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

            # Remove the tied weight keys from the tied params -> each group should have at most one name left afterwards
            for key in tied_weight_keys:
                for i in range(len(tied_params)):
                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]

            tied_params = [group for group in tied_params if len(group) > 1]
            self.assertListEqual(
                tied_params,
                [],
                f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
            )

    def test_model_weights_reload_no_missing_tied_weights(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)

                # We are nuking ALL weights on file, so every parameter should be reported as missing on
                # load. We then check that we do not report too many, or too few, missing keys.
                placeholder_dict = {"tensor": torch.tensor([1, 2])}
                safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
                model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)

                prefix = f"{model_reloaded.base_model_prefix}."
                params = dict(model_reloaded.named_parameters())
                params.update(dict(model_reloaded.named_buffers()))
                param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
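                # Strip the base-model prefix so the parameter names can be compared with the reported missing keys.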

                missing_keys = set(infos["missing_keys"])

                extra_missing = missing_keys - param_names
                # Remove tied weights from extra missing: they are normally not warned as missing if their tied
                # counterpart is present but here there are no weights at all so we do get the warning.
                ptrs = collections.defaultdict(list)
                for name, tensor in model_reloaded.state_dict().items():
                    ptrs[id_tensor_storage(tensor)].append(name)
                tied_params = [names for _, names in ptrs.items() if len(names) > 1]
                for group in tied_params:
                    group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
                    # We remove the group from extra_missing if not all weights from group are in it
                    if len(group - extra_missing) > 0:
                        extra_missing = extra_missing - set(group)

                self.assertEqual(
                    extra_missing,
                    set(),
                    f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}. "
                    f"For debugging, tied parameters are {tied_params}",
                )

                missed_missing = param_names - missing_keys
                # Remove nonpersistent buffers from missed_missing
                buffers = [n for n, _ in model_reloaded.named_buffers()]
                nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
                nonpersistent_buffers = {
                    k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
                }
                missed_missing = missed_missing - nonpersistent_buffers

                if model_reloaded._keys_to_ignore_on_load_missing is None:
                    expected_missing = set()
                else:
                    expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
                self.assertEqual(
                    missed_missing,
                    expected_missing,
                    f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
                    " parameters. If they are non persistent buffers make sure to instantiate them with"
                    " `persistent=False`",
                )

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
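            # NaN is the only value that compares unequal to itself, so `t != t` selects exactly the NaN entries.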
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}."
                                f" Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`:"
                                f" {torch.isinf(dict_object).any()}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            if self.has_attentions:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(
                    model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
                )

    # Don't copy this method to model specific test file!
    # TODO: remove this method once the issues are all fixed!
    def _make_attention_mask_non_null(self, inputs_dict):
        """Make sure no sequence has all zeros as attention mask"""

        for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
            if k in inputs_dict:
                attention_mask = inputs_dict[k]

                # Make sure there are no all-zero attention masks - to avoid failures for the moment.
                # Put `1` at the beginning of sequences so this still works when combined with causal attention masks.
                # TODO: remove this line once a fix regarding large negative values for attention mask is done.
                attention_mask = torch.cat(
                    [torch.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], dim=-1
                )

                # Here we would make the first sequence have an all-zero attention mask.
                # Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
                # values, like `-1e4`, `-1e9`, `-1e30` and `-inf`, used for attention masks across models/frameworks.
                # TODO: enable this block once the large negative values thing is cleaned up.
                # (see https://github.com/huggingface/transformers/issues/14859)
                # attention_mask = torch.cat(
                #     [torch.zeros_like(attention_mask[:1], dtype=attention_mask.dtype), attention_mask[1:]],
                #     dim=0
                # )

                inputs_dict[k] = attention_mask

    # Don't copy this method to model specific test file!
    # TODO: remove this method once the issues are all fixed!
    def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
        """For temporarily ignoring some failed test cases (issues to be fixed)"""

        tf_keys = {k for k, v in tf_outputs.items() if v is not None}
        pt_keys = {k for k, v in pt_outputs.items() if v is not None}

        key_differences = tf_keys.symmetric_difference(pt_keys)

        if model_class.__name__ in [
            "FlaubertWithLMHeadModel",
            "FunnelForPreTraining",
            "ElectraForPreTraining",
            "XLMWithLMHeadModel",
        ]:
            for k in key_differences:
                if k in ["loss", "losses"]:
                    tf_keys.discard(k)
                    pt_keys.discard(k)
        elif model_class.__name__.startswith("GPT2"):
            # `TFGPT2` has `past_key_values` as a tensor while `GPT2` has it as a tuple.
            tf_keys.discard("past_key_values")
            pt_keys.discard("past_key_values")

        # create new outputs from the remaining fields
        new_tf_outputs = type(tf_outputs)(**{k: tf_outputs[k] for k in tf_keys})
        new_pt_outputs = type(pt_outputs)(**{k: pt_outputs[k] for k in pt_keys})

        return new_tf_outputs, new_pt_outputs

    # Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
        """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.

        Args:
            model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
                `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more
                informative error messages.
            name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
            attributes (`Tuple[str]`): The names of the output's elements if the output is a tuple/list, with each
                element being a named field in the output.
        """

        self.assertEqual(type(name), str)
        if attributes is not None:
            self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")

        # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
        if isinstance(tf_outputs, ModelOutput):
            self.assertTrue(
                isinstance(pt_outputs, ModelOutput),
                f"{name}: `pt_outputs` should be an instance of `ModelOutput` when `tf_outputs` is",
            )

            # Don't copy this block to model specific test file!
            # TODO: remove this method and this line after issues are fixed
            tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)

            tf_keys = [k for k, v in tf_outputs.items() if v is not None]
            pt_keys = [k for k, v in pt_outputs.items() if v is not None]

            self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")

            # convert to the case of `tuple`
            # appending each key to the current (string) `name`
            attributes = tuple([f"{name}.{k}" for k in tf_keys])
            self.check_pt_tf_outputs(
                tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
            )

        # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
        elif type(tf_outputs) in [tuple, list]:
            self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
            self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")

            if attributes is not None:
                # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
                self.assertEqual(
                    len(attributes),
                    len(tf_outputs),
                    f"{name}: The tuple `attributes` should have the same length as `tf_outputs`",
                )
            else:
                # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
                attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])

            for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
                self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)

        elif isinstance(tf_outputs, tf.Tensor):
            self.assertTrue(
                isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `tf_outputs` is"
            )

            tf_outputs = tf_outputs.numpy()
            pt_outputs = pt_outputs.detach().to("cpu").numpy()

            self.assertEqual(
                tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
            )

            # deal with NumPy's scalars to make replacing nan values by 0 work.
            if np.isscalar(tf_outputs):
                tf_outputs = np.array([tf_outputs])
                pt_outputs = np.array([pt_outputs])

            tf_nans = np.isnan(tf_outputs)
            pt_nans = np.isnan(pt_outputs)

            pt_outputs[tf_nans] = 0
            tf_outputs[tf_nans] = 0
            pt_outputs[pt_nans] = 0
            tf_outputs[pt_nans] = 0

            max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
            self.assertLessEqual(max_diff, tol, f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}).")
        else:
            raise ValueError(
                "`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `tf.Tensor`. Got"
                f" {type(tf_outputs)} instead."
            )

    def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
        tf_inputs_dict = {}
        for key, tensor in pt_inputs_dict.items():
            # booleans are passed through unchanged
            if isinstance(tensor, bool):
                tf_inputs_dict[key] = tensor
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
            elif key == "input_features":
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
            # other general float inputs
            elif tensor.is_floating_point():
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)

        return tf_inputs_dict

    def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
        tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)

        # send pytorch inputs to the correct device
        pt_inputs_dict = {
            k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
        }

        # send pytorch model to the correct device
        pt_model.to(torch_device)

        # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
        pt_model.eval()

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs_dict)
        tf_outputs = tf_model(tf_inputs_dict)

        # tf models returned loss is usually a tensor rather than a scalar.
        # (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
        # Change it here to a scalar to match PyTorch models' loss
        tf_loss = getattr(tf_outputs, "loss", None)
        if tf_loss is not None:
            tf_outputs.loss = tf.math.reduce_mean(tf_loss)

        self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))

    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
        import transformers

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning
            if not hasattr(transformers, tf_model_class_name):
                self.skipTest(reason="transformers does not have TF version of this model yet")

            # Output all for aggressive testing
            config.output_hidden_states = True
            config.output_attentions = self.has_attentions

            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistent
            # use of `-1e4`, `-1e9`, `-1e30` and `-inf` as masking values across models.
            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
            self._make_attention_mask_non_null(inputs_dict)

            tf_model_class = getattr(transformers, tf_model_class_name)

            pt_model = model_class(config)
            tf_model = tf_model_class(config)

            pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            pt_inputs_dict_with_labels = self._prepare_for_class(
                inputs_dict,
                model_class,
                # Not all models accept "labels" in the forward pass (yet :) )
                return_labels=True if "labels" in inspect.signature(model_class.forward).parameters.keys() else False,
            )

            # make sure only tf inputs are forward that actually exist in function args
            tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

            # remove all head masks
            tf_input_keys.discard("head_mask")
            tf_input_keys.discard("cross_attn_head_mask")
            tf_input_keys.discard("decoder_head_mask")

            pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
            pt_inputs_dict_with_labels = {k: v for k, v in pt_inputs_dict_with_labels.items() if k in tf_input_keys}

            # For some models (e.g. base models), there is no label returned.
            # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
            if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
                pt_inputs_dict_with_labels = None

            # Check we can load pt model in tf and vice-versa with model => model functions
            # Here requires `tf_inputs_dict` to build `tf_model`
            tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
            )
            pt_model = transformers.load_tf2_model_in_pytorch_model(
                pt_model, tf_model, allow_missing_keys=allow_missing_keys
            )

            # Original test: check without `labels`
            self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
            # check with `labels`
            if pt_inputs_dict_with_labels:
                self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict_with_labels)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
                )

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
                )

            # Original test: check without `labels`
            self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
            # check with `labels`
            if pt_inputs_dict_with_labels:
                self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict_with_labels)

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        diff = np.abs((a - b)).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")

    def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
        """
        Args:
            model_class: The class of the model that is currently being tested. For example, ..., etc.
            Currently unused, but it could make debugging easier and faster.

            name: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
                Currently unused, but in the future, we could use this information to make the error message clearer
                by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
        """

        self.assertEqual(type(name), str)
        if attributes is not None:
            self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")

        # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
        if isinstance(fx_outputs, ModelOutput):
            self.assertTrue(
                isinstance(pt_outputs, ModelOutput),
                f"{name}: `pt_outputs` should be an instance of `ModelOutput` when `fx_outputs` is",
            )

            fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
            pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

            self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")

            # convert to the case of `tuple`
            # appending each key to the current (string) `name`
            attributes = tuple([f"{name}.{k}" for k in fx_keys])
            self.check_pt_flax_outputs(
                fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
            )

        # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
        elif type(fx_outputs) in [tuple, list]:
            self.assertEqual(
                type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
            )
            self.assertEqual(
                len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
            )

            if attributes is not None:
                # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
                self.assertEqual(
                    len(attributes),
                    len(fx_outputs),
                    f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
                )
            else:
                # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
                attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])

            for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
                self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)

        elif isinstance(fx_outputs, jnp.ndarray):
            self.assertTrue(
                isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should be a tensor when `fx_outputs` is"
            )

            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
            fx_outputs = np.array(fx_outputs)
            pt_outputs = pt_outputs.detach().to("cpu").numpy()

            self.assertEqual(
                fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
            )

            # deal with NumPy's scalars to make replacing nan values by 0 work.
            if np.isscalar(fx_outputs):
                fx_outputs = np.array([fx_outputs])
                pt_outputs = np.array([pt_outputs])

            fx_nans = np.isnan(fx_outputs)
            pt_nans = np.isnan(pt_outputs)

            pt_outputs[fx_nans] = 0
            fx_outputs[fx_nans] = 0
            pt_outputs[pt_nans] = 0
            fx_outputs[pt_nans] = 0

            max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
            self.assertLessEqual(
                max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
            )
        else:
            raise ValueError(
                "`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
                f" {type(fx_outputs)} instead."
            )

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    self.skipTest(reason="No Flax model exists for this class")

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, dtype=jnp.float32)

                # make sure only flax inputs are forward that actually exist in function args
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    self.skipTest(reason="No Flax model exists for this class")

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned as a default.
                # So we disable `use_cache` here for PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, dtype=jnp.float32)

                # make sure only flax inputs are forward that actually exist in function args
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()
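                # (loading Flax parameters may replace previously tied tensors with fresh copies, so re-tie them here)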

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = model_class.from_pretrained(
                        tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
                    )

                # send pytorch model to the correct device
                pt_model_loaded.to(torch_device)
                pt_model_loaded.eval()

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class.__name__ not in get_values(MODEL_MAPPING_NAMES):
                continue
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            model_forward_args = inspect.signature(model.forward).parameters
            if "inputs_embeds" not in model_forward_args:
                self.skipTest(reason="This model doesn't use `inputs_embeds`")

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                # Some models infer the position ids / attention mask differently from the input ids by checking for
                # the pad token, so make sure there is no padding in the input ids here.
                not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
                input_ids[input_ids == pad_token_id] = not_pad_token_id
                del inputs["input_ids"]
                inputs_embeds = wte(input_ids)
                with torch.no_grad():
                    out_ids = model(input_ids=input_ids, **inputs)[0]
                    out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
                decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)
                inputs_embeds = wte(encoder_input_ids)
                decoder_inputs_embeds = wte(decoder_input_ids)
                with torch.no_grad():
                    out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0]
                    out_embeds = model(
                        inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs
                    )[0]
            self.assertTrue(torch.allclose(out_embeds, out_ids))

    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # some params shouldn't be scattered by nn.DataParallel
        # so just remove them if they are present.
        blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
        for k in blacklist_non_batched_params:
            inputs_dict.pop(k, None)

        # move input tensors to cuda:0
        for k, v in inputs_dict.items():
            if torch.is_tensor(v):
                inputs_dict[k] = v.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = nn.DataParallel(model)
            with torch.no_grad():
                _ = model(**self._prepare_for_class(inputs_dict, model_class))

    @require_torch_multi_gpu
    def test_model_parallelization(self):
        if not self.test_model_parallel:
            self.skipTest(reason="test_model_parallel is set to False")

        # a candidate for testing_utils
        def get_current_gpu_memory_use():
            """returns a list of cuda memory allocations per GPU in MBs"""

            per_device_memory = []
            for id in range(torch.cuda.device_count()):
                with torch.cuda.device(id):
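                    # `memory_allocated()` returns bytes; shifting right by 20 bits converts the value to MiB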
                    per_device_memory.append(torch.cuda.memory_allocated() >> 20)

            return per_device_memory

        # Needs a large model to see the difference.
        config = self.model_tester.get_large_model_config()

        for model_class in self.all_parallelizable_model_classes:
            torch.cuda.empty_cache()

            # 1. single gpu memory load + unload + memory measurements
            # Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests)
            memory_at_start = get_current_gpu_memory_use()

            # Put model on device 0 and take a memory snapshot
            model = model_class(config)
            model.to("cuda:0")
            memory_after_model_load = get_current_gpu_memory_use()

            # The memory use on device 0 should be higher than it was initially.
            self.assertGreater(memory_after_model_load[0], memory_at_start[0])

            del model
            gc.collect()
            torch.cuda.empty_cache()

            # 2. MP test
            # it's essential to re-calibrate the usage before the next stage
            memory_at_start = get_current_gpu_memory_use()

            # Spread model layers over multiple devices
            model = model_class(config)
            model.parallelize()
            memory_after_parallelization = get_current_gpu_memory_use()

            # Assert that the memory use on all devices is higher than it was when loaded only on CPU
            for n in range(len(model.device_map.keys())):
                self.assertGreater(memory_after_parallelization[n], memory_at_start[n])

            # Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
            self.assertLess(memory_after_parallelization[0], memory_after_model_load[0])

            # Assert that the memory use of device 1 is higher than it was when the entire model was loaded
            # on device 0 and device 1 wasn't used at all
            self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1])

            del model
            gc.collect()
            torch.cuda.empty_cache()

    @require_torch_multi_gpu
    def test_model_parallel_equal_results(self):
        if not self.test_model_parallel:
            self.skipTest(reason="test_model_parallel is set to False")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_parallelizable_model_classes:
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)

            def cast_to_device(dictionary, device):
                output = {}
                for k, v in dictionary.items():
                    if isinstance(v, torch.Tensor):
                        output[k] = v.to(device)
                    else:
                        output[k] = v

                return output

            model = model_class(config)
            output = model(**cast_to_device(inputs_dict, "cpu"))

            model.parallelize()

            parallel_output = model(**cast_to_device(inputs_dict, "cuda:0"))

            for value, parallel_value in zip(output, parallel_output):
                if isinstance(value, torch.Tensor):
                    self.assertTrue(torch.allclose(value, parallel_value.to("cpu"), atol=1e-7))
                elif isinstance(value, (Tuple, List)):
                    for value_, parallel_value_ in zip(value, parallel_value):
                        self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

    def check_device_map_is_respected(self, model, device_map):
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError(f"device map is incomplete, it does not contain any device for {param_name}.")

            param_device = device_map[param_name]
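            # Weights offloaded to CPU or disk are left as placeholders on the `meta` device until they are needed.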
            if param_device in ["cpu", "disk"]:
                self.assertEqual(param.device, torch.device("meta"))
            elif param_device in ["mps"]:
                self.assertEqual(param.device, torch.device("mps"))
            else:
                # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
                self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_gpu
    def test_disk_offload_bin(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class._no_split_modules is None:
                continue

            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config).eval()
            model = model.to(torch_device)
            torch.manual_seed(0)
            base_output = model(**inputs_dict_class)

            model_size = compute_module_sizes(model)[""]
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

                with self.assertRaises(ValueError):
                    max_size = int(self.model_split_percents[0] * model_size)
                    max_memory = {0: max_size, "cpu": max_size}
                    # This errors out cause it's missing an offload folder
                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

                max_size = int(self.model_split_percents[1] * model_size)
                max_memory = {0: max_size, "cpu": max_size}
                new_model = model_class.from_pretrained(
                    tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
                )

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict_class)

                if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
                    # `all(...)` is needed here: a bare generator passed to `assertTrue` is always truthy.
                    self.assertTrue(
                        all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
                    )
                else:
                    self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_gpu
    def test_disk_offload_safetensors(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class._no_split_modules is None:
                continue

            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config).eval()
            model = model.to(torch_device)
            torch.manual_seed(0)
            base_output = model(**inputs_dict_class)

            model_size = compute_module_sizes(model)[""]
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.cpu().save_pretrained(tmp_dir)

                max_size = int(self.model_split_percents[1] * model_size)
                max_memory = {0: max_size, "cpu": max_size}

                # This doesn't error out as it's in safetensors and doesn't need an offload folder
                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict_class)

                if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
                    self.assertTrue(
                        all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
                    )
                else:
                    self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_gpu
    def test_cpu_offload(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class._no_split_modules is None:
                continue

            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config).eval()
            model = model.to(torch_device)

            torch.manual_seed(0)
            base_output = model(**inputs_dict_class)

            model_size = compute_module_sizes(model)[""]
            # We test several splits of sizes to make sure it works.
            max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.cpu().save_pretrained(tmp_dir)

                for max_size in max_gpu_sizes:
                    max_memory = {0: max_size, "cpu": model_size * 2}
                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                    # Making sure part of the model will actually end up offloaded
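                    # (`hf_device_map` maps module names to a GPU index, or to "cpu"/"disk" for offloaded modules)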
                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                    torch.manual_seed(0)
                    new_output = new_model(**inputs_dict_class)

                    if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
                        self.assertTrue(
                            all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
                        )
                    else:
                        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_accelerate
    @mark.accelerate_tests
    @require_torch_multi_accelerator
    def test_model_parallelism(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class._no_split_modules is None:
                continue

            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config).eval()
            model = model.to(torch_device)

            torch.manual_seed(0)
            base_output = model(**inputs_dict_class)

            model_size = compute_module_sizes(model)[""]
            # We test several splits of sizes to make sure it works.
            max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
            with tempfile.TemporaryDirectory() as tmp_dir:
                model.cpu().save_pretrained(tmp_dir)

                for max_size in max_gpu_sizes:
                    max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                    # Making sure the model is actually split across the two devices
                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                    torch.manual_seed(0)
                    new_output = new_model(**inputs_dict_class)

                    if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
                        self.assertTrue(
                            all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
                        )
                    else:
                        self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    def test_problem_types(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        problem_types = [
            {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
            {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
            {"title": "regression", "num_labels": 1, "dtype": torch.float},
        ]

        for model_class in self.all_model_classes:
            if model_class.__name__ not in [
                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES),
                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
            ]:
                continue

            for problem_type in problem_types:
                with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
                    config.problem_type = problem_type["title"]
                    config.num_labels = problem_type["num_labels"]

                    model = model_class(config)
                    model.to(torch_device)
                    model.train()

                    inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

                    if problem_type["num_labels"] > 1:
                        inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])

                    inputs["labels"] = inputs["labels"].to(problem_type["dtype"])

                    # This tests that we do not trigger the warning from PyTorch "Using a target size that is different
                    # to the input size. This will likely lead to incorrect results due to broadcasting. Please ensure
                    # they have the same size." which is a symptom that something is wrong for the regression problem.
                    # See https://github.com/huggingface/transformers/issues/11780
                    with warnings.catch_warnings(record=True) as warning_list:
                        loss = model(**inputs).loss
                    for w in warning_list:
                        if "Using a target size that is different to the input size" in str(w.message):
                            raise ValueError(
                                f"Something is going wrong in the regression problem: intercepted {w.message}"
                            )

                    loss.backward()

    def test_load_with_mismatched_shapes(self):
        if not self.test_mismatched_shapes:
            self.skipTest(reason="test_mismatched_shapes is set to False")
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                continue

            with self.subTest(msg=f"Testing {model_class}"):
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model = model_class(config)
                    model.save_pretrained(tmp_dir)

                    # Fails when we don't set ignore_mismatched_sizes=True
                    with self.assertRaises(RuntimeError):
                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
                    with self.assertRaises(RuntimeError):
                        new_model_without_prefix = AutoModel.from_pretrained(tmp_dir, vocab_size=10)

                    logger = logging.get_logger("transformers.modeling_utils")

                    with CaptureLogger(logger) as cl:
                        new_model = AutoModelForSequenceClassification.from_pretrained(
                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                        )
                    self.assertIn("the shapes did not match", cl.out)
                    new_model.to(torch_device)
                    inputs = self._prepare_for_class(inputs_dict, model_class)
                    logits = new_model(**inputs).logits
                    self.assertEqual(logits.shape[1], 42)

                    with CaptureLogger(logger) as cl:
                        new_model_without_prefix = AutoModel.from_pretrained(
                            tmp_dir, vocab_size=10, ignore_mismatched_sizes=True
                        )
                    self.assertIn("the shapes did not match", cl.out)
                    input_ids = ids_tensor((2, 8), 10)
                    new_model_without_prefix.to(torch_device)
                    if self.is_encoder_decoder:
                        new_model_without_prefix(input_ids, decoder_input_ids=input_ids)
                    else:
                        new_model_without_prefix(input_ids)

    def test_mismatched_shapes_have_properly_initialized_weights(self):
        if not self.test_mismatched_shapes:
            self.skipTest(reason="test_missmatched_shapes is set to False")
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)

        for model_class in self.all_model_classes:
            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
                continue

            with self.subTest(msg=f"Testing {model_class}"):
                with tempfile.TemporaryDirectory() as tmp_dir:
                    model = model_class(configs_no_init)
                    model.save_pretrained(tmp_dir)

                    # Fails when we don't set ignore_mismatched_sizes=True
                    with self.assertRaises(RuntimeError):
                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)

                    logger = logging.get_logger("transformers.modeling_utils")

                    with CaptureLogger(logger) as cl:
                        new_model = AutoModelForSequenceClassification.from_pretrained(
                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                        )
                    self.assertIn("the shapes did not match", cl.out)

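                    # With a zero-init config, freshly initialized parameters should have a (rounded) mean of 0.0, or 1.0 for e.g. LayerNorm weights.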
                    for name, param in new_model.named_parameters():
                        if param.requires_grad:
                            self.assertIn(
                                ((param.data.mean() * 1e9).round() / 1e9).item(),
                                [0.0, 1.0],
                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                            )

    def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
        # 1. Create a dummy class. (Should it have buffers as well, to make sure we test __init__?)
        class MyClass(PreTrainedModel):
            config_class = PretrainedConfig

            def __init__(self, config=None):
                super().__init__(config if config is not None else PretrainedConfig())
                self.linear = nn.Linear(10, config.num_labels, bias=True)
                self.embedding = nn.Embedding(10, 10)
                self.std = 1

            def _init_weights(self, module):
                if isinstance(module, nn.Linear):
                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
                    if module.bias is not None:
                        module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)

        # Used to make sure the weights with matched shape are loaded correctly
        config = PretrainedConfig()
        config.num_labels = 3
        model = MyClass(config=config)

        # Used to make sure the weights with mismatched shape are properly initialized
        set_seed(0)
        config = PretrainedConfig()
        config.num_labels = 4
        # Don't init. the weights during creation, to match the logic in `from_pretrained`: this keeps the same
        # sequence of random ops in the execution path, so we can compare `target_model` and `new_model` below
        # for the `linear` part.
        with ContextManagers([no_init_weights(True)]):
            target_model = MyClass(config=config)
        target_model.apply(target_model._initialize_weights)

        with tempfile.TemporaryDirectory() as tmpdirname:
            state_dict = model.state_dict()
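            # Drop `linear.weight` from the checkpoint so it has to be freshly initialized when loading.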
            del state_dict["linear.weight"]

            model.config.save_pretrained(tmpdirname)
            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

            set_seed(0)
            new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)

            for key in new_model.state_dict().keys():
                # check weight values for weights with matched shapes are identical
                # (i.e. correctly loaded from the checkpoint)
                if key not in ["linear.weight", "linear.bias"]:
                    max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
                    self.assertLessEqual(
                        max_diff.item(),
                        1e-6,
                        msg=f"the weight values for `{key}` in `new_model` and `model` are  not identical",
                    )
                else:
                    # check we have some mismatched shapes
                    self.assertNotEqual(
                        model.state_dict()[key].shape,
                        new_model.state_dict()[key].shape,
                        msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
                    )
                    # check the weights with mismatched shape are properly initialized
                    max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
                    self.assertLessEqual(
                        max_diff.item(),
                        1e-6,
                        msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
                    )

    def test_model_is_small(self):
        # Just a consistency check to make sure we are not running tests on 80M parameter models.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            num_params = model.num_parameters()
            assert (
                num_params < 1000000
            ), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_conversion(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
                ).to(torch_device)

                for _, module in model.named_modules():
                    if "FlashAttention" in module.__class__.__name__:
                        return

                self.assertTrue(False, "FlashAttention2 modules not found in model")

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    @is_flaky()
    def test_flash_attn_2_inference_equivalence(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                model.to(torch_device)

                dummy_input = inputs_dict[model.main_input_name][:1]
                if dummy_input.dtype in [torch.float32, torch.float16]:
                    dummy_input = dummy_input.to(torch.bfloat16)

                dummy_attention_mask = inputs_dict.get("attention_mask", None)

                if dummy_attention_mask is not None:
                    dummy_attention_mask = dummy_attention_mask[:1]
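                    # Simulate left padding: only the first token is masked out.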
                    dummy_attention_mask[:, 1:] = 1
                    dummy_attention_mask[:, :1] = 0

                if model.config.is_encoder_decoder:
                    decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]

                    outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                    outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                else:
                    outputs = model(dummy_input, output_hidden_states=True)
                    outputs_fa = model_fa(dummy_input, output_hidden_states=True)

                logits = (
                    outputs.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs.decoder_hidden_states[-1]
                )
                logits_fa = (
                    outputs_fa.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs_fa.decoder_hidden_states[-1]
                )

                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)

                if model.config.is_encoder_decoder:
                    other_inputs = {
                        "decoder_input_ids": decoder_input_ids,
                        "decoder_attention_mask": dummy_attention_mask,
                        "output_hidden_states": True,
                    }
                    if dummy_attention_mask is not None:
                        other_inputs["attention_mask"] = dummy_attention_mask

                    outputs = model(dummy_input, **other_inputs)
                    outputs_fa = model_fa(dummy_input, **other_inputs)
                else:
                    other_inputs = {
                        "output_hidden_states": True,
                    }
                    if dummy_attention_mask is not None:
                        other_inputs["attention_mask"] = dummy_attention_mask

                    outputs = model(dummy_input, **other_inputs)
                    outputs_fa = model_fa(dummy_input, **other_inputs)

                logits = (
                    outputs.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs.decoder_hidden_states[-1]
                )
                logits_fa = (
                    outputs_fa.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs_fa.decoder_hidden_states[-1]
                )

                assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)

                # check with inference + dropout
                model.train()
                _ = model_fa(dummy_input, **other_inputs)

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    @is_flaky()
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                model.to(torch_device)

                dummy_input = inputs_dict[model.main_input_name][:1]
                if dummy_input.dtype in [torch.float32, torch.float16]:
                    dummy_input = dummy_input.to(torch.bfloat16)

                dummy_attention_mask = inputs_dict.get("attention_mask", None)

                if dummy_attention_mask is not None:
                    dummy_attention_mask = dummy_attention_mask[:1]
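                    # Simulate right padding: only the last token is masked out.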
                    dummy_attention_mask[:, :-1] = 1
                    dummy_attention_mask[:, -1:] = 0

                if model.config.is_encoder_decoder:
                    decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]

                    outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                    outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                else:
                    outputs = model(dummy_input, output_hidden_states=True)
                    outputs_fa = model_fa(dummy_input, output_hidden_states=True)

                logits = (
                    outputs.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs.decoder_hidden_states[-1]
                )
                logits_fa = (
                    outputs_fa.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs_fa.decoder_hidden_states[-1]
                )

                assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)

                if model.config.is_encoder_decoder:
                    other_inputs = {
                        "decoder_input_ids": decoder_input_ids,
                        "decoder_attention_mask": dummy_attention_mask,
                        "output_hidden_states": True,
                    }
                    if dummy_attention_mask is not None:
                        other_inputs["attention_mask"] = dummy_attention_mask

                    outputs = model(dummy_input, **other_inputs)
                    outputs_fa = model_fa(dummy_input, **other_inputs)
                else:
                    other_inputs = {
                        "output_hidden_states": True,
                    }
                    if dummy_attention_mask is not None:
                        other_inputs["attention_mask"] = dummy_attention_mask

                    outputs = model(dummy_input, **other_inputs)
                    outputs_fa = model_fa(dummy_input, **other_inputs)

                logits = (
                    outputs.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs.decoder_hidden_states[-1]
                )
                logits_fa = (
                    outputs_fa.hidden_states[-1]
                    if not model.config.is_encoder_decoder
                    else outputs_fa.decoder_hidden_states[-1]
                )

                assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    @is_flaky()
    def test_flash_attn_2_generate_left_padding(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                dummy_input = inputs_dict[model.main_input_name]
                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                    dummy_input = dummy_input.to(torch.float16)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # make sure we do left padding
                dummy_attention_mask[:, :-1] = 0
                dummy_attention_mask[:, -1:] = 1

                out = model.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                )

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                out_fa = model.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                )

                self.assertTrue(torch.allclose(out, out_fa))

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @is_flaky()
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                dummy_input = inputs_dict[model.main_input_name]
                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                    dummy_input = dummy_input.to(torch.float16)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # make sure we do right padding
                dummy_attention_mask[:, :-1] = 1
                dummy_attention_mask[:, -1:] = 0

                out = model.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                )

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                out_fa = model.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                )

                self.assertTrue(torch.allclose(out, out_fa))

    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_inference(self, torch_dtype: str):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if not self.all_model_classes[0]._supports_sdpa:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")

        if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
            self.skipTest(
                f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
            )

        # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available, so hacking it here instead.
        if torch_dtype == "float16":
            torch_dtype = torch.float16
        elif torch_dtype == "bfloat16":
            torch_dtype = torch.bfloat16
        elif torch_dtype == "float32":
            torch_dtype = torch.float32

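        # Tolerances keyed by (device, enable_kernels, dtype); reduced-precision dtypes get looser bounds.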
        atols = {
            ("cpu", False, torch.float32): 1e-6,
            ("cpu", False, torch.bfloat16): 1e-2,
            ("cpu", True, torch.float32): 1e-6,
            ("cpu", True, torch.bfloat16): 1e-2,
            ("cuda", False, torch.float32): 1e-6,
            ("cuda", False, torch.bfloat16): 1e-2,
            ("cuda", False, torch.float16): 5e-3,
            ("cuda", True, torch.float32): 1e-6,
            ("cuda", True, torch.bfloat16): 1e-2,
            ("cuda", True, torch.float16): 5e-3,
        }
        rtols = {
            ("cpu", False, torch.float32): 1e-4,
            ("cpu", False, torch.bfloat16): 1e-2,
            ("cpu", True, torch.float32): 1e-4,
            ("cpu", True, torch.bfloat16): 1e-2,
            ("cuda", False, torch.float32): 1e-4,
            ("cuda", False, torch.bfloat16): 1e-2,
            ("cuda", False, torch.float16): 5e-3,
            ("cuda", True, torch.float32): 1e-4,
            ("cuda", True, torch.bfloat16): 3e-2,
            ("cuda", True, torch.float16): 5e-3,
        }

        def get_mean_reldiff(failcase, x, ref, atol, rtol):
            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)
            # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
            # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
            # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
            # However, masking there is not done at any layer that matters (i.e. self-attention), therefore we can safely deactivate it.
            deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters

            is_encoder_decoder = model.config.is_encoder_decoder

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
                model_sdpa = model_sdpa.eval().to(torch_device)

                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch_dtype,
                    attn_implementation="eager",
                )
                model_eager = model_eager.eval().to(torch_device)

                self.assertTrue(model_eager.config._attn_implementation == "eager")

                for name, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")

                has_sdpa = False
                for name, submodule in model_sdpa.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        has_sdpa = True
                        break
                if not has_sdpa and model_sdpa.config.model_type != "falcon":
                    raise ValueError("The SDPA model should have SDPA attention layers")

                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
                # but it would be nicer to have an efficient way to use parameterized.expand
                fail_cases = []
                for padding_side in ["left", "right"]:
                    for use_mask in [False, True]:
                        for output_attentions in [True, False]:
                            can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
                            if not (self.has_attentions and can_output_attn) and output_attentions:
                                continue
                            for batch_size in [1, 5]:
                                dummy_input = inputs_dict[model.main_input_name]

                                if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
                                    dummy_input = dummy_input.to(torch_dtype)

                                dummy_input = dummy_input[:batch_size]
                                if dummy_input.shape[0] != batch_size:
                                    if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
                                        extension = torch.rand(
                                            batch_size - dummy_input.shape[0],
                                            *dummy_input.shape[1:],
                                            dtype=torch_dtype,
                                            device=torch_device,
                                        )
                                        dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
                                    else:
                                        extension = torch.randint(
                                            high=5,
                                            size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
                                            dtype=dummy_input.dtype,
                                            device=torch_device,
                                        )
                                        dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)

                                if not use_mask:
                                    dummy_attention_mask = None
                                else:
                                    dummy_attention_mask = inputs_dict.get("attention_mask", None)
                                    if dummy_attention_mask is None:
                                        if is_encoder_decoder:
                                            seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
                                        else:
                                            seqlen = dummy_input.shape[-1]
                                        dummy_attention_mask = (
                                            torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
                                        )

                                    dummy_attention_mask = dummy_attention_mask[:batch_size]
                                    if dummy_attention_mask.shape[0] != batch_size:
                                        extension = torch.ones(
                                            batch_size - dummy_attention_mask.shape[0],
                                            *dummy_attention_mask.shape[1:],
                                            dtype=dummy_attention_mask.dtype,
                                            device=torch_device,
                                        )
                                        dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
                                        dummy_attention_mask = dummy_attention_mask.to(torch_device)

                                    dummy_attention_mask[:] = 1
                                    if padding_side == "left":
                                        dummy_attention_mask[-1, :-1] = 1
                                        dummy_attention_mask[-1, -4:] = 0
                                    elif padding_side == "right":
                                        dummy_attention_mask[-1, 1:] = 1
                                        dummy_attention_mask[-1, :3] = 0

                                for enable_kernels in [False, True]:
                                    failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
                                    if is_encoder_decoder:
                                        decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
                                            :batch_size
                                        ]
                                        if decoder_input_ids.shape[0] != batch_size:
                                            extension = torch.ones(
                                                batch_size - decoder_input_ids.shape[0],
                                                *decoder_input_ids.shape[1:],
                                                dtype=decoder_input_ids.dtype,
                                                device=torch_device,
                                            )
                                            decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
                                            decoder_input_ids = decoder_input_ids.to(torch_device)

                                        # TODO: never an `attention_mask` arg here?
                                        processed_inputs = {
                                            model.main_input_name: dummy_input,
                                            "decoder_input_ids": decoder_input_ids,
                                            "decoder_attention_mask": dummy_attention_mask,
                                            "output_hidden_states": True,
                                        }
                                    else:
                                        processed_inputs = {
                                            model.main_input_name: dummy_input,
                                            "output_hidden_states": True,
                                        }

                                        # Otherwise fails for e.g. WhisperEncoderModel
                                        if "attention_mask" in inspect.signature(model_eager.forward).parameters:
                                            processed_inputs["attention_mask"] = dummy_attention_mask

                                        if (
                                            self.has_attentions
                                            and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
                                        ):
                                            processed_inputs["output_attentions"] = output_attentions
                                    if not deactivate_mask and (
                                        "bool_masked_pos" in inspect.signature(model_eager.forward).parameters
                                    ):
                                        dummy_mask = torch.ones((self.model_tester.num_masks,))

                                        # In case of additional token (like class) we define a custom `mask_length`
                                        if hasattr(self.model_tester, "mask_length"):
                                            mask_length = self.model_tester.mask_length - dummy_mask.size(0)
                                        else:
                                            mask_length = self.model_tester.seq_length - dummy_mask.size(0)
                                        dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
                                        dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
                                        processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)

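                                    # Some vision models (e.g. MAE-style) take a `noise` input; seed it so eager and SDPA see the same noise.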
                                    if "noise" in inspect.signature(model_eager.forward).parameters:
                                        np.random.seed(2)
                                        num_patches = int(
                                            (self.model_tester.image_size // self.model_tester.patch_size) ** 2
                                        )
                                        noise = np.random.uniform(size=(batch_size, num_patches))
                                        processed_inputs["noise"] = torch.from_numpy(noise)

                                    # TODO: test gradients as well (& for FA2 as well!)
                                    with torch.no_grad():
                                        with torch.backends.cuda.sdp_kernel(
                                            enable_flash=enable_kernels,
                                            enable_math=True,
                                            enable_mem_efficient=enable_kernels,
                                        ):
                                            prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
                                            outputs_eager = model_eager(**prepared_inputs)
                                            outputs_sdpa = model_sdpa(**prepared_inputs)

                                    logits_eager = (
                                        outputs_eager.hidden_states[-1]
                                        if not is_encoder_decoder
                                        else outputs_eager.decoder_hidden_states[-1]
                                    )
                                    logits_sdpa = (
                                        outputs_sdpa.hidden_states[-1]
                                        if not is_encoder_decoder
                                        else outputs_sdpa.decoder_hidden_states[-1]
                                    )

                                    if torch_device in ["cpu", "cuda"]:
                                        atol = atols[torch_device, enable_kernels, torch_dtype]
                                        rtol = rtols[torch_device, enable_kernels, torch_dtype]
                                    else:
                                        atol = 1e-7
                                        rtol = 1e-4

                                    # Masked tokens output slightly deviates - we don't mind that.
                                    if use_mask:
                                        if padding_side == "left":
                                            sub_sdpa = logits_sdpa[:-1]
                                            sub_eager = logits_eager[:-1]
                                            if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                fail_cases.append(
                                                    get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                )

                                            sub_sdpa = logits_sdpa[-1, :-4]
                                            sub_eager = logits_eager[-1, :-4]
                                            if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                fail_cases.append(
                                                    get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                )

                                            # Testing the padding tokens is not really meaningful but anyway
                                            # sub_sdpa = logits_sdpa[-1, -4:]
                                            # sub_eager = logits_eager[-1, -4:]
                                            # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                            #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
                                        elif padding_side == "right":
                                            sub_sdpa = logits_sdpa[:-1]
                                            sub_eager = logits_eager[:-1]
                                            if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                fail_cases.append(
                                                    get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                )

                                            sub_sdpa = logits_sdpa[-1, 3:]
                                            sub_eager = logits_eager[-1, 3:]
                                            if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                                fail_cases.append(
                                                    get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
                                                )

                                            # Testing the padding tokens is not really meaningful but anyway
                                            # sub_sdpa = logits_sdpa[-1, :3]
                                            # sub_eager = logits_eager[-1, :3]
                                            # if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
                                            #     fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
                                    else:
                                        if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
                                            fail_cases.append(
                                                get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
                                            )

                self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))

    @require_torch_sdpa
    @require_torch_gpu
    @slow
    def test_sdpa_can_dispatch_on_flash(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        compute_capability = torch.cuda.get_device_capability()
        major, _ = compute_capability

        if not torch.version.cuda or major < 8:
            self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

        for model_class in self.all_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
                self.skipTest(
                    reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
                )
            if config.model_type in ["paligemma"]:
                self.skipTest(
                    "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
                )
            if config.model_type in ["idefics"]:
                self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input")
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
                model.to(torch_device)

                inputs_dict.pop("attention_mask", None)
                inputs_dict.pop("decoder_attention_mask", None)

                for name, inp in inputs_dict.items():
                    if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                        inputs_dict[name] = inp.to(torch.float16)

                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                    _ = model(**inputs_dict)

    @require_torch_sdpa
    @require_torch_gpu
    @slow
    def test_sdpa_can_compile_dynamic(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        compute_capability = torch.cuda.get_device_capability()
        major, _ = compute_capability

        if not torch.version.cuda or major < 8:
            self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0")

        for model_class in self.all_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            if config.model_type in ["dbrx"]:
                self.skipTest(
                    "DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
                )
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
                model.to(torch_device)

                # For PyTorch 2.1 - 2.3.0 set `dynamic=True`. In the future setting `dynamic=None` and using `torch._dynamo.mark_dynamic()`
                # on input tensors will be required. `mark_dynamic` currently raises inconsistent shape errors.
                model = torch.compile(model, dynamic=True)

                inputs_dict.pop("attention_mask", None)
                inputs_dict.pop("decoder_attention_mask", None)
                for name, inp in inputs_dict.items():
                    if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                        inputs_dict[name] = inp.to(torch.float16)

                # use no_grad to save some memory
                with torch.no_grad():
                    _ = model(**inputs_dict)

    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        max_new_tokens = 30

        if len(self.all_generative_model_classes) == 0:
            self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))

                model_sdpa = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    attn_implementation="eager",
                ).to(torch_device)

                self.assertTrue(model_eager.config._attn_implementation == "eager")

                for name, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")

                has_sdpa = False
                for name, submodule in model_sdpa.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        has_sdpa = True
                        break
                if not has_sdpa:
                    raise ValueError("The SDPA model should have SDPA attention layers")

                # Just test that a large cache works as expected
                res_eager = model_eager.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                res_sdpa = model_sdpa.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                self.assertTrue(torch.allclose(res_eager, res_sdpa))

    @require_torch_sdpa
    def test_sdpa_matches_eager_sliding_window(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen_moe", "starcoder2"]

        if len(self.all_generative_model_classes) == 0:
            self.skipTest(f"No generative model classes for {self.__class__.__name__}")

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            if config.model_type not in WINDOW_ATTENTION_MODELS:
                self.skipTest(f"{config.model_type} does not use window attention")

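            # Use a small sliding window so the sliding-window code path is actually exercised on these short inputs.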
            config.sliding_window = 2

            dummy_input = inputs_dict[model_class.main_input_name]
            attention_mask = inputs_dict["attention_mask"]

            self.assertTrue(dummy_input.ndim == 2)
            self.assertTrue(dummy_input.shape[1] > 6)

            with tempfile.TemporaryDirectory() as tmpdir:
                with torch.device(torch_device):
                    model_eager = AutoModelForCausalLM.from_config(
                        config, attn_implementation="eager", torch_dtype=torch.float32
                    )

                model_eager.save_pretrained(tmpdir)

                with torch.device(torch_device):
                    model_sdpa = AutoModelForCausalLM.from_pretrained(
                        tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32
                    )

                model_eager = model_eager.eval()
                model_sdpa = model_sdpa.eval()

                with torch.no_grad():
                    with torch.backends.cuda.sdp_kernel(
                        enable_flash=False,
                        enable_math=True,
                        enable_mem_efficient=False,
                    ):
                        res_eager = model_eager(**inputs_dict, return_dict=False)[0]
                        res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]

                # Only non-padding tokens are expected to match.
                self.assertTrue(
                    torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
                )

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_use_cache(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                # Just test that a large cache works as expected
                _ = model.generate(
                    dummy_input,
                    attention_mask=dummy_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    use_cache=True,
                )

    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_fp32_ln(self):
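        """
        Checks that a Flash Attention 2 model loaded in 4-bit still runs forward passes after its remaining
        half-precision parameters (in practice the layer norms) are upcast to float32.
        """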
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_input = inputs_dict[model.main_input_name]
                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                batch_size = dummy_attention_mask.shape[0]

                is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size

                # To avoid errors with padding_side=="right"
                if is_padding_right:
                    dummy_attention_mask = torch.ones_like(dummy_input)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                    load_in_4bit=True,
                )

                for _, param in model.named_parameters():
                    # upcast the remaining half-precision parameters (in practice the layer norms) to fp32
                    if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                        param.data = param.data.to(torch.float32)

                if model.config.is_encoder_decoder:
                    dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
                    dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]

                    _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
                    # with attention mask
                    _ = model(
                        dummy_input,
                        attention_mask=dummy_attention_mask,
                        decoder_input_ids=dummy_decoder_input_ids,
                        decoder_attention_mask=dummy_decoder_attention_mask,
                    )
                else:
                    _ = model(dummy_input)
                    # with attention mask
                    _ = model(dummy_input, attention_mask=dummy_attention_mask)

    @is_pt_tf_cross_test
    def test_tf_from_pt_safetensors(self):
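        """
        Checks that the TF counterpart of each model loads identically from safetensors and PyTorch bin
        checkpoints saved by the PyTorch model.
        """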
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning
            if not hasattr(transformers, tf_model_class_name):
                self.skipTest(reason="transformers does not have this model in TF version yet")

            tf_model_class = getattr(transformers, tf_model_class_name)

            pt_model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname, safe_serialization=True)
                tf_model_1 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)

                pt_model.save_pretrained(tmpdirname, safe_serialization=False)
                tf_model_2 = tf_model_class.from_pretrained(tmpdirname, from_pt=True)

                # Check models are equal
                for p1, p2 in zip(tf_model_1.weights, tf_model_2.weights):
                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @is_pt_flax_cross_test
    def test_flax_from_pt_safetensors(self):
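        """
        Checks that the Flax counterpart of each model loads identically from safetensors and PyTorch bin
        checkpoints saved by the PyTorch model.
        """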
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            flax_model_class_name = "Flax" + model_class.__name__  # Add the "Flax" at the beginning
            if not hasattr(transformers, flax_model_class_name):
                self.skipTest(reason="transformers does not have this model in Flax version yet")

            flax_model_class = getattr(transformers, flax_model_class_name)

            pt_model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname, safe_serialization=True)
                flax_model_1 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

                pt_model.save_pretrained(tmpdirname, safe_serialization=False)
                flax_model_2 = flax_model_class.from_pretrained(tmpdirname, from_pt=True)

                # Check models are equal
                self.assertTrue(check_models_equal(flax_model_1, flax_model_2))

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_from_config(self):
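        """
        Checks that `from_config(..., attn_implementation="flash_attention_2")` instantiates Flash Attention
        modules, and that the Flash Attention setting is not silently persisted when saving and reloading.
        """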
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            # TODO: to change it in the future with other relevant auto classes
            fa2_model = AutoModelForCausalLM.from_config(
                config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
            ).to(torch_device)

            dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
            dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)

            fa2_correctly_converted = False

            for _, module in fa2_model.named_modules():
                if "FlashAttention" in module.__class__.__name__:
                    fa2_correctly_converted = True
                    break

            self.assertTrue(fa2_correctly_converted)

            _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)

            with tempfile.TemporaryDirectory() as tmpdirname:
                fa2_model.save_pretrained(tmpdirname)

                model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)

                self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")

                fa2_correctly_converted = False

                for _, module in model_from_pretrained.named_modules():
                    if "FlashAttention" in module.__class__.__name__:
                        fa2_correctly_converted = True
                        break

                self.assertFalse(fa2_correctly_converted)

    def _get_custom_4d_mask_test_data(self):
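        """
        Builds test data for packed sequences sharing a common prefix: the original batched inputs plus a
        single packed row with a 4D attention mask and position ids in which the three final tokens see the
        shared prefix but not each other.
        """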
        # Sequence in which all but the last token is the same
        input_ids = torch.tensor(
            [[10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]], device=torch_device, dtype=torch.int64
        )
        position_ids = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)

        # Combining common prefix with the unique ending tokens:
        input_ids_shared_prefix = torch.cat([input_ids[0][:-1], input_ids[:, -1]]).unsqueeze(0)

        # Creating a 4D mask in which the last 3 tokens attend to the shared prefix but not to each other.
        mask_shared_prefix = torch.tensor(
            [
                [
                    [
                        [1, 0, 0, 0, 0, 0],
                        [1, 1, 0, 0, 0, 0],
                        [1, 1, 1, 0, 0, 0],
                        [1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 1, 0],
                        [1, 1, 1, 0, 0, 1],
                    ]
                ]
            ],
        )
        # inverting the attention mask
        mask_dtype = torch.float32
        min_dtype = torch.finfo(mask_dtype).min
        mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=mask_dtype, device=torch_device) * min_dtype

        # Creating a position_ids tensor. Note the repeated positions at the end.
        position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

        return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix

    def test_custom_4d_attention_mask(self):
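        """
        Checks that packing sequences with a shared prefix into a single row, using a custom 4D attention
        mask and repeated position ids, reproduces the last-token logits of the equivalent batched input.
        """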
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if len(self.all_generative_model_classes) == 0:
            self.skipTest(
                reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
            )

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_static_cache:
                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            if getattr(config, "sliding_window", 0) > 0:
                self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
            model = model_class(config).to(device=torch_device, dtype=torch.float32)

            (
                input_ids,
                position_ids,
                input_ids_shared_prefix,
                mask_shared_prefix,
                position_ids_shared_prefix,
            ) = self._get_custom_4d_mask_test_data()

            logits = model.forward(input_ids, position_ids=position_ids).logits
            # logits.shape == torch.Size([3, 4, ...])

            logits_shared_prefix = model(
                input_ids_shared_prefix,
                attention_mask=mask_shared_prefix,
                position_ids=position_ids_shared_prefix,
            )[0]
            # logits_shared_prefix.shape == torch.Size([1, 6, ...])

            out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
            out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

            # comparing greedily-chosen tokens:
            assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)

            # comparing softmax-normalized logits:
            normalized_0 = F.softmax(out_last_tokens, dim=-1)
            normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
            torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

    # For now, let's focus only on GPU for `torch.compile`
    @slow
    @require_torch_gpu
    @require_read_token
    def test_torch_compile(self):
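        """
        Smoke test: generation with a static cache and a `torch.compile`-ed forward pass should run end to
        end on the checkpoint declared in `_torch_compile_test_ckpt`.
        """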
        if version.parse(torch.__version__) < version.parse("2.3"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        if not hasattr(self, "_torch_compile_test_ckpt"):
            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
        ckpt = self._torch_compile_test_ckpt

        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        batch_size = 1
        n_iter = 3

        tokenizer = AutoTokenizer.from_pretrained(ckpt)
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)

        model.generation_config.max_new_tokens = 4

        model.generation_config.cache_implementation = "static"
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

        input_text = "Why dogs are cute?"
        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)

        for i in range(n_iter):
            _ = model.generate(**input_ids, do_sample=False)


global_rng = random.Random()


def ids_tensor(shape, vocab_size, rng=None, name=None):
    # Creates a random int64 (torch.long) tensor of the given shape with values in [0, vocab_size)
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


def random_attention_mask(shape, rng=None, name=None):
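    """Creates a random attention mask, ensuring the first token of every row is attended to."""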
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng, name=name)
    # make sure that at least one token is attended to for each batch
    # we choose the 1st token so this property of `at least one being non-zero` still holds after applying causal mask
    attn_mask[:, 0] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()