# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import inspect
import os.path
import random
import tempfile
import unittest
from typing import List, Tuple

from transformers import is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device


if is_torch_available():
    import numpy as np
    import torch

    from transformers import (
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        MODEL_FOR_CAUSAL_LM_MAPPING,
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        MODEL_MAPPING,
        AdaptiveEmbedding,
        BertConfig,
        BertModel,
        PretrainedConfig,
        PreTrainedModel,
        T5ForConditionalGeneration,
    )


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init
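
# Illustrative usage sketch (not additional test logic): `_config_zero_init` is used by the tests
# below to shrink every `*_range` / `*_std` / `initializer_factor` value to ~0, so that freshly
# initialized parameters land on 0.0 (or 1.0 for pure scaling factors) and the initialization
# checks become deterministic. A typical call, mirroring `test_initialization` further down:
#
#     configs_no_init = _config_zero_init(config)
#     model = SomeModelClass(config=configs_no_init)  # `SomeModelClass` is a placeholder name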

TINY_T5 = "patrickvonplaten/t5-tiny-random"


@require_torch
class ModelTesterMixin:

    model_tester = None
    all_model_classes = ()
    all_generative_model_classes = ()
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
    test_head_masking = True
    test_missing_keys = True
    test_model_parallel = False
    is_encoder_decoder = False
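
    # Hedged usage sketch: concrete test modules combine this mixin with `unittest.TestCase`,
    # fill in the class attributes above, and build `self.model_tester` in `setUp`. The names
    # below are placeholders, not classes defined in this file.
    #
    #     @require_torch
    #     class MyModelTest(ModelTesterMixin, unittest.TestCase):
    #         all_model_classes = (MyModel,) if is_torch_available() else ()
    #         test_pruning = False
    #
    #         def setUp(self):
    #             self.model_tester = MyModelTester(self)  # hypothetical helper exposing prepare_config_and_inputs_for_common()
    #
    # Each `test_*` method defined below then runs against every class in `all_model_classes`.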

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)
        if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
            inputs_dict = {
                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                if isinstance(v, torch.Tensor) and v.ndim > 1
                else v
                for k, v in inputs_dict.items()
            }

        if return_labels:
            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in [
                *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
                *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in [
                *MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
                *MODEL_FOR_CAUSAL_LM_MAPPING.values(),
                *MODEL_FOR_MASKED_LM_MAPPING.values(),
                *MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
        return inputs_dict
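
    # Note on the helper above: multiple-choice models get every 2D tensor expanded to
    # (batch_size, num_choices, seq_length), and with `return_labels=True` dummy integer labels
    # (or start/end positions) of the shape each head expects are added, so the training tests
    # below can compute a loss without any real data.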

    def test_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            out_2 = outputs[0].cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)
                model.to(torch_device)
                with torch.no_grad():
                    after_outputs = model(**self._prepare_for_class(inputs_dict, model_class))

                # Make sure we don't have nans
                out_1 = after_outputs[0].cpu().numpy()
                out_1[np.isnan(out_1)] = 0
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)

    def test_save_load__keys_to_ignore_on_save(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
            if _keys_to_ignore_on_save is None:
                continue

            # check the keys are in the original state_dict
            for k in _keys_to_ignore_on_save:
                self.assertIn(k, model.state_dict())

            # check that certain keys didn't get saved with the model
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
                state_dict_saved = torch.load(output_model_file)
                for k in _keys_to_ignore_on_save:
                    self.assertNotIn(k, state_dict_saved)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def test_determinism(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model.config.is_encoder_decoder:
                expected_arg_names = [
                    "input_ids",
                    "attention_mask",
                    "decoder_input_ids",
                    "decoder_attention_mask",
                ]
                expected_arg_names.extend(
                    ["head_mask", "decoder_head_mask", "encoder_outputs"]
                    if "head_mask" in arg_names and "decoder_head_mask" in arg_names
                    else ["encoder_outputs"]
                )
                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
            else:
                expected_arg_names = ["input_ids"]
                self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_training(self):
        if not self.model_tester.is_training:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        for model_class in self.all_model_classes:
            if model_class in MODEL_MAPPING.values():
                continue
            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
            return

        config.gradient_checkpointing = True
        config.use_cache = False
        config.return_dict = True

        for model_class in self.all_model_classes:
            if model_class in MODEL_MAPPING.values():
                continue
            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # loss is at first position
                if "labels" in inputs_dict:
                    correct_outlen += 1  # loss is added to beginning
                # Question Answering model returns start_logits and end_logits
                if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output
                if "past_key_values" in outputs:
                    correct_outlen += 1  # past_key_values have been returned

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            if chunk_length is not None:
                self.assertListEqual(
                    list(self_attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )

    def test_torchscript(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torchscript(config, inputs_dict)

    def test_torchscript_output_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_attentions = True
        self._create_and_check_torchscript(config, inputs_dict)

    def test_torchscript_output_hidden_state(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        self._create_and_check_torchscript(config, inputs_dict)

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)

            try:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similar to BART afterward
                    input_ids = inputs["input_ids"]
                    attention_mask = inputs["attention_mask"]
                    decoder_input_ids = inputs["decoder_input_ids"]
                    decoder_attention_mask = inputs["decoder_attention_mask"]
                    traced_model = torch.jit.trace(
                        model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
                    )
                else:
                    input_ids = inputs["input_ids"]
                    traced_model = torch.jit.trace(model, input_ids)
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_headmasking(self):
        if not self.test_head_masking:
            return

        global_rng.seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        global_rng.seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            # Prepare head_mask
            # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
            head_mask = torch.ones(
                self.model_tester.num_hidden_layers,
                self.model_tester.num_attention_heads,
                device=torch_device,
            )
            head_mask[0, 0] = 0
            head_mask[-1, :-1] = 0
            head_mask.requires_grad_(requires_grad=True)
            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask
            if model.config.is_encoder_decoder:
                signature = inspect.signature(model.forward)
                arg_names = [*signature.parameters.keys()]
                if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
                    inputs["decoder_head_mask"] = head_mask
            outputs = model(**inputs, return_dict=True)

            # Test that we can get a gradient back for importance score computation
            output = sum(t.sum() for t in outputs[0])
            output = output.sum()
            output.backward()
            multihead_outputs = head_mask.grad

            self.assertIsNotNone(multihead_outputs)
            self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)

            def check_attentions_validity(attentions):
                # Remove Nan
                for t in attentions:
                    self.assertLess(
                        torch.sum(torch.isnan(t)), t.numel() / 4
                    )  # Check we don't have more than 25% nans (arbitrary)
                attentions = [
                    t.masked_fill(torch.isnan(t), 0.0) for t in attentions
                ]  # remove them (the test is less complete)

                self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
                if len(attentions) > 2:  # encoder-decoder models have only 2 layers in each module
                    self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
                self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)

            if model.config.is_encoder_decoder:
                check_attentions_validity(outputs.encoder_attentions)
                check_attentions_validity(outputs.decoder_attentions)
            else:
                check_attentions_validity(outputs.attentions)

    def test_head_pruning(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_pretrained(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_config_init(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            config.pruned_heads = heads_to_prune

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_integration(self):
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (
                config,
                inputs_dict,
            ) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            heads_to_prune = {0: [0], 1: [1, 2]}
            config.pruned_heads = heads_to_prune

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            heads_to_prune = {0: [0], 2: [1, 2]}
            model.prune_heads(heads_to_prune)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
                if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
                    seq_length = seq_length * self.model_tester.chunk_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        output = outputs[0]

        if config.is_encoder_decoder:
            # Seq2Seq models
            encoder_hidden_states = outputs.encoder_hidden_states[0]
            encoder_attentions = outputs.encoder_attentions[0]
            encoder_hidden_states.retain_grad()
            encoder_attentions.retain_grad()

            decoder_hidden_states = outputs.decoder_hidden_states[0]
            decoder_attentions = outputs.decoder_attentions[0]
            decoder_hidden_states.retain_grad()
            decoder_attentions.retain_grad()

            cross_attentions = outputs.cross_attentions[0]
            cross_attentions.retain_grad()

            output.flatten()[0].backward(retain_graph=True)

            self.assertIsNotNone(encoder_hidden_states.grad)
            self.assertIsNotNone(encoder_attentions.grad)
            self.assertIsNotNone(decoder_hidden_states.grad)
            self.assertIsNotNone(decoder_attentions.grad)
            self.assertIsNotNone(cross_attentions.grad)
        else:
            # Encoder-/Decoder-only models
            hidden_states = outputs.hidden_states[0]
            attentions = outputs.attentions[0]

            hidden_states.retain_grad()
            attentions.retain_grad()

            output.flatten()[0].backward(retain_graph=True)

            self.assertIsNotNone(hidden_states.grad)
            self.assertIsNotNone(attentions.grad)

    def test_feed_forward_chunking(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            torch.manual_seed(0)
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            torch.manual_seed(0)
            config.chunk_size_feed_forward = 1
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
            self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))

    def test_resize_tokens_embeddings(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)

            # make sure that decoder_input_ids are resized as well
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_resize_embeddings_untied(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if model cannot untied embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
            model.set_input_embeddings(torch.nn.Embedding(10, 10))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, torch.nn.Linear))

    def test_correct_missing_keys(self):
        if not self.test_missing_keys:
            return
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            base_model_prefix = model.base_model_prefix

            if hasattr(model, base_model_prefix):
                with tempfile.TemporaryDirectory() as temp_dir_name:
                    model.base_model.save_pretrained(temp_dir_name)
                    model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)

                    with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"):
                        self.assertGreater(len(loading_info["missing_keys"]), 0)

    def test_tie_model_weights(self):
        if not self.test_torchscript:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_same_values(layer_1, layer_2):
            equal = True
            for p1, p2 in zip(layer_1.weight, layer_2.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    equal = False
            return equal

        for model_class in self.all_model_classes:
            config.torchscript = True
            model_not_tied = model_class(config)
            if model_not_tied.get_output_embeddings() is None:
                continue

            config_tied = copy.deepcopy(config)
            config_tied.torchscript = False
            model_tied = model_class(config_tied)
            params_tied = list(model_tied.parameters())
            # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(check_same_values(embeddings, decoding))

            # # Check that after modification, they remain the same.
            # embeddings.weight.data.div_(2)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
            # self.assertTrue(check_same_values(embeddings, decoding))

            # # Check that after modification, they remain the same.
            # decoding.weight.data.div_(4)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
            # self.assertTrue(check_same_values(embeddings, decoding))

            # Check that after resize they remain tied.
            model_tied.resize_token_embeddings(config.vocab_size + 10)
            params_tied_2 = list(model_tied.parameters())
            self.assertEqual(len(params_tied_2), len(params_tied))

            # decoding.weight.data.mul_(20)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
            # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

    def test_model_outputs_equivalence(self):

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t
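
        # `t != t` is only true for NaN entries, so the helper above zeroes out NaNs in place;
        # this keeps the tuple-vs-dict comparison below meaningful for models that can emit NaNs.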

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}.",
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(
                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
            )

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # some params shouldn't be scattered by nn.DataParallel
        # so just remove them if they are present.
Patrick von Platen's avatar
Patrick von Platen committed
1081
        blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
        for k in blacklist_non_batched_params:
            inputs_dict.pop(k, None)

        # move input tensors to cuda:O
        for k, v in inputs_dict.items():
            if torch.is_tensor(v):
                inputs_dict[k] = v.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = torch.nn.DataParallel(model)
            with torch.no_grad():
1098
                _ = model(**self._prepare_for_class(inputs_dict, model_class))
1099

    @require_torch_multi_gpu
    def test_model_parallelization(self):
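        # Loads the model on a single GPU, records the memory footprint, then calls
        # `parallelize()` and checks, via per-device memory usage, that the weights
        # end up spread across several devices.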
        if not self.test_model_parallel:
            return

        # a candidate for testing_utils
        def get_current_gpu_memory_use():
            """Returns a list of current CUDA memory allocations per GPU (in MiB)."""

            per_device_memory = []
            for id in range(torch.cuda.device_count()):
                with torch.cuda.device(id):
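                    # memory_allocated() reports bytes; shifting right by 20 bits converts to MiB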
                    per_device_memory.append(torch.cuda.memory_allocated() >> 20)

            return per_device_memory

        # Needs a large model to see the difference.
        config = self.model_tester.get_large_model_config()

        for model_class in self.all_parallelizable_model_classes:
            torch.cuda.empty_cache()

            # 1. single gpu memory load + unload + memory measurements
            # Retrieve initial memory usage (can easily be ~0.6-1.5GB if cuda-kernels have been preloaded by previous tests)
            memory_at_start = get_current_gpu_memory_use()

            # Put model on device 0 and take a memory snapshot
            model = model_class(config)
            model.to("cuda:0")
            memory_after_model_load = get_current_gpu_memory_use()

            # The memory use on device 0 should be higher than it was initially.
            self.assertGreater(memory_after_model_load[0], memory_at_start[0])

            del model
            gc.collect()
            torch.cuda.empty_cache()

            # 2. MP test
            # it's essential to re-calibrate the usage before the next stage
            memory_at_start = get_current_gpu_memory_use()

            # Spread model layers over multiple devices
            model = model_class(config)
            model.parallelize()
            memory_after_parallelization = get_current_gpu_memory_use()

            # Assert that the memory use on all devices is higher than it was when loaded only on CPU
            for n in range(torch.cuda.device_count()):
                self.assertGreater(memory_after_parallelization[n], memory_at_start[n])

            # Assert that the memory use of device 0 is lower than it was when the entire model was loaded on it
            self.assertLess(memory_after_parallelization[0], memory_after_model_load[0])

            # Assert that the memory use of device 1 is higher than it was when the entire model was loaded
            # on device 0 and device 1 wasn't used at all
            self.assertGreater(memory_after_parallelization[1], memory_after_model_load[1])

            del model
            gc.collect()
            torch.cuda.empty_cache()

    @require_torch_multi_gpu
    def test_model_parallel_equal_results(self):
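        # Runs the same inputs through the model on CPU and again on GPU after `parallelize()`,
        # and checks that both forward passes return matching outputs (atol=1e-7).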
        if not self.test_model_parallel:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_parallelizable_model_classes:
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)

            def cast_to_device(dictionary, device):
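                # Moves every tensor value of the inputs dict to `device`; non-tensor values are kept as-is.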
                output = {}
                for k, v in dictionary.items():
                    if isinstance(v, torch.Tensor):
                        output[k] = v.to(device)
                    else:
                        output[k] = v

                return output

            model = model_class(config)
            output = model(**cast_to_device(inputs_dict, "cpu"))

            model.parallelize()

            parallel_output = model(**cast_to_device(inputs_dict, "cuda:0"))

            for value, parallel_value in zip(output, parallel_output):
                if isinstance(value, torch.Tensor):
                    self.assertTrue(torch.allclose(value, parallel_value.to("cpu"), atol=1e-7))
                elif isinstance(value, (Tuple, List)):
                    for value_, parallel_value_ in zip(value, parallel_value):
                        self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

    @require_torch_multi_gpu
    def test_model_parallel_beam_search(self):
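        # Makes sure that `generate()` with beam search still works once the model
        # has been split across devices with `parallelize()`.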
        if not self.test_model_parallel:
            return

        all_generative_and_parallelizable_model_classes = tuple(
            set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
        )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in all_generative_and_parallelizable_model_classes:
            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)

            def cast_to_device(dictionary, device):
                output = {}
                for k, v in dictionary.items():
                    if isinstance(v, torch.Tensor):
                        output[k] = v.to(device)
                    else:
                        output[k] = v

                return output

            model.parallelize()
            model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)


global_rng = random.Random()


def ids_tensor(shape, vocab_size, rng=None, name=None):
    # Creates a random torch.long tensor of the given shape with values in [0, vocab_size)
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


def random_attention_mask(shape, rng=None, name=None):
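    # Creates a random 0/1 attention mask of the given shape.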
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng, name=name)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()


@require_torch
class ModelUtilsTest(unittest.TestCase):
    @slow
    def test_model_from_pretrained(self):
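        # Downloads the first BERT checkpoint, checks that config and model load cleanly
        # (no missing or unexpected keys) and that config overrides are propagated to the model.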
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)

            # Not sure this is the intended behavior. TODO (Lysandre & Thom): fix if needed.
            config.name_or_path = model_name

            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)

    def test_model_from_pretrained_with_different_pretrained_model_name(self):
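        # Loading a checkpoint of a different model type (a tiny T5) through `BertModel`
        # should raise an error with an explicit message.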
        model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
        self.assertIsNotNone(model)

        with self.assertRaises(Exception) as context:
            BertModel.from_pretrained(TINY_T5)
        self.assertTrue("You tried to initiate a model of type" in str(context.exception))