test_modeling_common.py 40.7 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
import copy
17
import os.path
Aymeric Augustin's avatar
Aymeric Augustin committed
18
import random
19
import tempfile
thomwolf's avatar
thomwolf committed
20
import unittest
21
from typing import List
thomwolf's avatar
thomwolf committed
22

23
from transformers import is_torch_available
24
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
25

Aymeric Augustin's avatar
Aymeric Augustin committed
26

27
if is_torch_available():
thomwolf's avatar
thomwolf committed
28
    import torch
29
    import numpy as np
thomwolf's avatar
thomwolf committed
30

31
32
33
34
35
36
    from transformers import (
        AdaptiveEmbedding,
        PretrainedConfig,
        PreTrainedModel,
        BertModel,
        BertConfig,
37
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
38
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
39
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
40
        top_k_top_p_filtering,
41
    )
thomwolf's avatar
thomwolf committed
42

43

44
45
46
def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
47
        if "_range" in key or "_std" in key or "initializer_factor" in key:
Lysandre Debut's avatar
Lysandre Debut committed
48
            setattr(configs_no_init, key, 1e-10)
49
50
    return configs_no_init

thomwolf's avatar
thomwolf committed
51

52
53
54
55
56
@require_torch
class ModelTesterMixin:

    # Set by the concrete test case: an object exposing
    # prepare_config_and_inputs_for_common() plus the sizes the shared tests read.
    model_tester = None

    # Model classes exercised by every shared test below.
    all_model_classes = ()
    # NOTE(review): not read by the methods visible here; presumably used by
    # generation tests elsewhere in the file.
    all_generative_model_classes = ()

    # Per-model switches for the optional shared tests.
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
    test_head_masking = True
    test_missing_keys = True

    # When True, test_attention_outputs also checks the decoder attentions.
    is_encoder_decoder = False

65
66
67
68
    def _prepare_for_class(self, inputs_dict, model_class):
        """Adapt the shared inputs to the layout ``model_class`` expects.

        Multiple-choice models get every non-scalar tensor repeated along a new
        axis 1 of size ``self.model_tester.num_choices`` (via unsqueeze/expand,
        then made contiguous). Non-tensor values and 0-d tensors are passed
        through. Any other model class receives ``inputs_dict`` itself, not a copy.
        """
        if model_class not in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
            return inputs_dict

        expanded = {}
        for key, value in inputs_dict.items():
            if isinstance(value, torch.Tensor) and value.ndim != 0:
                value = value.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
            expanded[key] = value
        return expanded

Patrick von Platen's avatar
Patrick von Platen committed
75
    def test_save_load(self):
        """Check that a save_pretrained / from_pretrained round trip preserves outputs.

        The first output of each model class is computed before saving and after
        reloading. NaNs are zeroed on both sides, then the arrays must agree within
        an absolute tolerance of 1e-5.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                reference = model(**self._prepare_for_class(inputs_dict, model_class))[0].cpu().numpy()
            reference[np.isnan(reference)] = 0

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                reloaded = model_class.from_pretrained(tmpdirname)
                reloaded.to(torch_device)
                with torch.no_grad():
                    restored = reloaded(**self._prepare_for_class(inputs_dict, model_class))[0].cpu().numpy()

            # Make sure we don't have nans
            restored[np.isnan(restored)] = 0
            self.assertLessEqual(np.amax(np.abs(restored - reference)), 1e-5)
99

Patrick von Platen's avatar
Patrick von Platen committed
100
    def test_initialization(self):
        """With initializer scales shrunk to ~0, trainable parameters must average to 0 or 1.

        Each parameter with ``requires_grad`` has its mean rounded to 9 decimal
        places; that value must be exactly 0.0 or 1.0.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        configs_no_init = _config_zero_init(config)

        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
            for name, param in trainable:
                rounded_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
                self.assertIn(
                    rounded_mean,
                    [0.0, 1.0],
                    msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
                )
thomwolf's avatar
thomwolf committed
113

Patrick von Platen's avatar
Patrick von Platen committed
114
    def test_determinism(self):
        """Two identical forward passes in eval mode must yield the same first output.

        NaN entries are excluded from the numeric comparison, but they must occur
        at the same positions in both runs. The remaining values are compared
        element-wise with an absolute tolerance of 1e-5.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()

            # Filtering each array with its own NaN mask would pair up unrelated
            # elements (or fail with a broadcast error) whenever the NaN positions
            # differ. Differing NaN positions are themselves non-determinism, so
            # assert on them explicitly, then compare aligned non-NaN values.
            nan_1 = np.isnan(out_1)
            nan_2 = np.isnan(out_2)
            self.assertTrue(
                np.array_equal(nan_1, nan_2), msg="NaN positions differ between two identical forward passes"
            )
            keep = ~nan_1
            max_diff = np.amax(np.abs(out_1[keep] - out_2[keep]))
            self.assertLessEqual(max_diff, 1e-5)

Patrick von Platen's avatar
Patrick von Platen committed
131
    def test_attention_outputs(self):
        """Check the count, position and shape of returned attention tensors.

        For each model class, attentions are requested first through the inputs
        and then through ``config.output_attentions``. Note that this mutates the
        shared ``config``. The last output must hold one attention tensor per
        hidden layer, shaped ``(..., heads, enc_seq, enc_key)``, or
        ``(..., heads, enc_seq, chunk_length, enc_key)`` when the tester defines
        ``chunk_length``. Encoder-decoder models must also expose decoder
        attentions at a fixed index. Additionally requesting hidden states must
        add one output (two for encoder-decoder) while attentions stay last.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        tester = self.model_tester

        seq_len = getattr(tester, "seq_length", None)
        decoder_seq_length = getattr(tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(tester, "encoder_seq_length", seq_len)
        # Key lengths are resolved before any num_hashes scaling below.
        decoder_key_length = getattr(tester, "key_length", decoder_seq_length)
        encoder_key_length = getattr(tester, "key_length", encoder_seq_length)
        chunk_length = getattr(tester, "chunk_length", None)
        if chunk_length is not None and hasattr(tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * tester.num_hashes

        def run_forward(model_class):
            # Fresh model built from the (possibly mutated) shared config/inputs.
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                return model(**self._prepare_for_class(inputs_dict, model_class))

        def check_encoder_attention_shape(attention_list):
            if chunk_length is not None:
                self.assertListEqual(
                    list(attention_list[0].shape[-4:]),
                    [tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(attention_list[0].shape[-3:]),
                    [tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            attentions = run_forward(model_class)[-1]
            self.assertEqual(len(attentions), tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            outputs = run_forward(model_class)
            attentions = outputs[-1]
            self.assertEqual(len(attentions), tester.num_hidden_layers)
            check_encoder_attention_shape(attentions)
            out_len = len(outputs)

            if self.is_encoder_decoder:
                # A leading loss (when labels are given) and the extra logits of
                # question-answering heads each shift the decoder attentions by one.
                extra = int("labels" in inputs_dict)
                extra += int(model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values())
                self.assertEqual(out_len, 4 + extra)

                decoder_attentions = outputs[1 + extra]
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            outputs = run_forward(model_class)
            self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))

            self_attentions = outputs[-1]
            self.assertEqual(len(self_attentions), tester.num_hidden_layers)
            check_encoder_attention_shape(self_attentions)
thomwolf's avatar
thomwolf committed
220

Patrick von Platen's avatar
Patrick von Platen committed
221
    def test_torchscript(self):
        """Run the TorchScript trace/save/load check with the tester's default config."""
        common_config, common_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self._create_and_check_torchscript(common_config, common_inputs)
thomwolf's avatar
thomwolf committed
224

Patrick von Platen's avatar
Patrick von Platen committed
225
    def test_torchscript_output_attentions(self):
        """Run the TorchScript check with ``config.output_attentions`` enabled."""
        common_config, common_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        common_config.output_attentions = True
        self._create_and_check_torchscript(common_config, common_inputs)
thomwolf's avatar
thomwolf committed
229

Patrick von Platen's avatar
Patrick von Platen committed
230
    def test_torchscript_output_hidden_state(self):
        """Run the TorchScript check with ``config.output_hidden_states`` enabled."""
        common_config, common_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        common_config.output_hidden_states = True
        self._create_and_check_torchscript(common_config, common_inputs)
thomwolf's avatar
thomwolf committed
234

235
    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace each model on ``input_ids``, save and reload the trace, and compare state dicts.

        Does nothing when ``test_torchscript`` is False. The checks run on a
        zero-initialized deep copy of ``config`` with ``torchscript = True``.
        Tracing, saving or loading failures are reported via ``self.fail``.
        """
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True

        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            # Let's keep only input_ids
            inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"]

            try:
                traced_model = torch.jit.trace(model, inputs)
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")
                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()
            loaded_model.to(torch_device)
            loaded_model.eval()

            expected_state = model.state_dict()
            loaded_state = loaded_model.state_dict()
            self.assertEqual(set(expected_state.keys()), set(loaded_state.keys()))

            # Names of every entry whose values differ in at least one element.
            mismatched = [
                name for name, tensor in expected_state.items() if tensor.data.ne(loaded_state[name].data).sum() > 0
            ]
            models_equal = not mismatched
            self.assertTrue(models_equal)
thomwolf's avatar
thomwolf committed
283

Patrick von Platen's avatar
Patrick von Platen committed
284
285
    def test_headmasking(self):
        """Check that ``head_mask`` silences the masked heads and receives a gradient.

        Masks head 0 of the first layer and every head except the last one in the
        final layer. It then backpropagates from the first output and checks that
        the mask got a gradient and that masked heads' attentions sum to ~0 while
        unmasked ones do not.
        """
        if not self.test_head_masking:
            return

        # NOTE(review): global_rng is defined outside this view; presumably the
        # RNG used to build the random inputs. Seeding with 42 makes them
        # reproducible, and seed() afterwards re-randomizes it.
        global_rng.seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        global_rng.seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan

        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            # Prepare head_mask. requires_grad is enabled only after the in-place
            # writes to avoid "leaf variable has been moved into the graph interior".
            head_mask = torch.ones(
                self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
            )
            head_mask[0, 0] = 0
            head_mask[-1, :-1] = 0
            head_mask.requires_grad_(requires_grad=True)

            inputs = dict(self._prepare_for_class(inputs_dict, model_class))
            inputs["head_mask"] = head_mask
            outputs = model(**inputs)

            # Test that we can get a gradient back for importance score computation
            total = sum(t.sum() for t in outputs[0]).sum()
            total.backward()
            head_grad = head_mask.grad

            attentions = outputs[-1]
            # Check we don't have more than 25% nans (arbitrary), then zero the
            # remaining ones (the test is less complete).
            for t in attentions:
                self.assertLess(torch.sum(torch.isnan(t)), t.numel() / 4)
            attentions = [t.masked_fill(torch.isnan(t), 0.0) for t in attentions]

            self.assertIsNotNone(head_grad)
            self.assertEqual(len(head_grad), self.model_tester.num_hidden_layers)

            def head_total(layer_idx, head_idx):
                # Sum of all attention weights of one head in one layer.
                return attentions[layer_idx][..., head_idx, :, :].flatten().sum().item()

            self.assertAlmostEqual(head_total(0, 0), 0.0)
            self.assertNotEqual(head_total(0, -1), 0.0)
            self.assertNotEqual(head_total(1, 0), 0.0)
            self.assertAlmostEqual(head_total(-1, -2), 0.0)
            self.assertNotEqual(head_total(-1, -1), 0.0)

Patrick von Platen's avatar
Patrick von Platen committed
338
339
    def test_head_pruning(self):
        """Prune heads via ``prune_heads`` and check the per-layer head counts."""
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict.pop("head_mask", None)
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            num_heads = self.model_tester.num_attention_heads
            # Layer 0 keeps only head 0; the layer keyed -1 loses head 0.
            model.prune_heads({0: list(range(1, num_heads)), -1: [0]})

            with torch.no_grad():
                attentions = model(**self._prepare_for_class(inputs_dict, model_class))[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], num_heads)
            self.assertEqual(attentions[-1].shape[-3], num_heads - 1)
LysandreJik's avatar
LysandreJik committed
366

Patrick von Platen's avatar
Patrick von Platen committed
367
368
    def test_head_pruning_save_load_from_pretrained(self):
        """Heads pruned with ``prune_heads`` must stay pruned after save/reload."""
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict.pop("head_mask", None)
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            num_heads = self.model_tester.num_attention_heads
            # Layer 0 keeps only head 0; the layer keyed -1 loses head 0.
            model.prune_heads({0: list(range(1, num_heads)), -1: [0]})

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            with torch.no_grad():
                attentions = model(**self._prepare_for_class(inputs_dict, model_class))[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], num_heads)
            self.assertEqual(attentions[-1].shape[-3], num_heads - 1)
399

Patrick von Platen's avatar
Patrick von Platen committed
400
401
    def test_head_pruning_save_load_from_config_init(self):
        """Heads listed in ``config.pruned_heads`` must be pruned at construction time."""
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict.pop("head_mask", None)
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            num_heads = self.model_tester.num_attention_heads
            # Layer 0 keeps only head 0; the layer keyed -1 loses head 0.
            config.pruned_heads = {0: list(range(1, num_heads)), -1: [0]}

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                attentions = model(**self._prepare_for_class(inputs_dict, model_class))[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], num_heads)
            self.assertEqual(attentions[-1].shape[-3], num_heads - 1)
430

Patrick von Platen's avatar
Patrick von Platen committed
431
432
    def test_head_pruning_integration(self):
        """Combine config-time pruning, save/reload and further ``prune_heads`` calls.

        Checks the head counts of the first four layers after each step. It also
        checks that ``config.pruned_heads`` ends up as the union of both pruning
        requests.
        """
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict.pop("head_mask", None)
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            config.pruned_heads = {0: [0], 1: [1, 2]}
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            def check_head_counts(expected_counts):
                # Forward the current `model` and compare layer-by-layer, in order.
                with torch.no_grad():
                    attentions = model(**self._prepare_for_class(inputs_dict, model_class))[-1]
                for layer_idx, expected in enumerate(expected_counts):
                    self.assertEqual(attentions[layer_idx].shape[-3], expected)

            num_heads = self.model_tester.num_attention_heads
            check_head_counts([num_heads - 1, num_heads - 2, num_heads, num_heads])

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            check_head_counts([num_heads - 1, num_heads - 2, num_heads, num_heads])

            # Head 0 of layer 0 is already pruned; layer 2 loses heads 1 and 2.
            model.prune_heads({0: [0], 2: [1, 2]})
            check_head_counts([num_heads - 1, num_heads - 2, num_heads - 2, num_heads])

            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
thomwolf's avatar
thomwolf committed
487

Patrick von Platen's avatar
Patrick von Platen committed
488
    def test_hidden_states_output(self):
        """Check the number and shape of the hidden states returned by each model.

        ``output_hidden_states`` is requested first through the inputs and then
        through the config. Note that this mutates the shared ``config``. The
        last output must hold ``num_hidden_layers + 1`` tensors. The first of
        them must end in ``(seq_length, hidden_size)``.
        """

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs[-1]

            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
                # When the tester defines a chunk_length > 1, the expected sequence
                # length is scaled by it. Read it with getattr and skip None, as
                # test_attention_outputs does: `None > 1` raises TypeError.
                chunk_length = getattr(self.model_tester, "chunk_length", None)
                if chunk_length is not None and chunk_length > 1:
                    seq_length = seq_length * chunk_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

Patrick von Platen's avatar
Patrick von Platen committed
522
    def test_resize_tokens_embeddings(self):
        """Grow then shrink the input embeddings and check the model stays consistent.

        Each resize must update ``config.vocab_size`` and the embedding row count,
        and the model must still run a forward pass. The rows shared by every
        size must be unchanged. ``inputs_dict["input_ids"]`` is clamped in place
        to fit the shrunken vocabulary.
        """
        (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            base_vocab = config.vocab_size
            # Retrieve the embeddings and clone them
            cloned_embeddings = model.resize_token_embeddings(base_vocab).weight.clone()

            def resize_by(delta):
                # Resize to base_vocab + delta and check both the config and the matrix.
                embeddings = model.resize_token_embeddings(base_vocab + delta)
                self.assertEqual(model.config.vocab_size, base_vocab + delta)
                self.assertEqual(embeddings.weight.shape[0], cloned_embeddings.shape[0] + delta)
                return embeddings

            # Growing: every parameter should be resized, so a forward pass still works.
            resize_by(10)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Shrinking: input ids must first be clamped to the smaller vocabulary.
            model_embed = resize_by(-15)
            inputs_dict["input_ids"].clamp_(max=base_vocab - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = all(
                not (p1.data.ne(p2.data).sum() > 0) for p1, p2 in zip(cloned_embeddings, model_embed.weight)
            )
            self.assertTrue(models_equal)

Patrick von Platen's avatar
Patrick von Platen committed
567
    def test_model_common_attributes(self):
        """Check the input/output embedding accessors every model must expose.

        The input embeddings must be an ``nn.Embedding`` or ``AdaptiveEmbedding``
        and must be replaceable. The output embeddings must be ``None`` or an
        ``nn.Linear``.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        allowed_input_types = (torch.nn.Embedding, AdaptiveEmbedding)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), allowed_input_types)
            model.set_input_embeddings(torch.nn.Embedding(10, 10))
            output_embeddings = model.get_output_embeddings()
            self.assertTrue(output_embeddings is None or isinstance(output_embeddings, torch.nn.Linear))

577
    def test_correct_missing_keys(self):
        """Loading a full model from a bare base-model checkpoint must report missing keys.

        Only model classes that hold an attribute named ``base_model_prefix`` are
        checked. Their ``base_model`` is saved alone and reloaded into
        ``model_class``, and ``loading_info["missing_keys"]`` must be non-empty.
        """
        if not self.test_missing_keys:
            return
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            if not hasattr(model, model.base_model_prefix):
                continue

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.base_model.save_pretrained(temp_dir_name)
                reloaded, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)

                with self.subTest(msg="Missing keys for {}".format(reloaded.__class__.__name__)):
                    self.assertGreater(len(loading_info["missing_keys"]), 0)

594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
    def test_tie_model_weights(self):
        """Check that resizing token embeddings keeps a tied model's parameter count unchanged.

        Runs only when ``self.test_torchscript`` is set. Model classes whose ``get_output_embeddings()``
        returns ``None`` are skipped.

        A model built with ``torchscript=False`` is resized by 10 tokens. The number of tensors yielded by
        ``parameters()`` must stay the same. Presumably this is because tied input/output embeddings share
        a single ``Parameter``, so untying them would add one (TODO confirm).
        """
        if not self.test_torchscript:
            return

        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # The variable name suggests torchscript mode leaves the weights untied. This model is only
            # used to decide whether the class has output embeddings at all.
            config.torchscript = True
            model_not_tied = model_class(config)
            if model_not_tied.get_output_embeddings() is None:
                continue

            config_tied = copy.deepcopy(config)
            config_tied.torchscript = False
            model_tied = model_class(config_tied)
            params_tied = list(model_tied.parameters())

            # Check that after resize they remain tied.
            model_tied.resize_token_embeddings(config.vocab_size + 10)
            params_tied_2 = list(model_tied.parameters())
            self.assertEqual(len(params_tied_2), len(params_tied))

Patrick von Platen's avatar
Patrick von Platen committed
642
    def test_inputs_embeds(self):
        """Every model must run a forward pass when fed embeddings instead of token ids.

        The ``input_ids`` entry (and, for encoder-decoder models, ``decoder_input_ids``) is replaced by the
        output of the model's own input-embedding module.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            embed_tokens = model.get_input_embeddings()

            if self.is_encoder_decoder:
                encoder_input_ids = inputs.pop("input_ids")
                # Fall back to the encoder ids when no decoder ids were supplied.
                decoder_input_ids = inputs.pop("decoder_input_ids", encoder_input_ids)
                inputs["inputs_embeds"] = embed_tokens(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = embed_tokens(decoder_input_ids)
            else:
                input_ids = inputs.pop("input_ids")
                inputs["inputs_embeds"] = embed_tokens(input_ids)

            with torch.no_grad():
                model(**inputs)
670

671
    def test_lm_head_model_random_no_beam_search_generate(self):
        """Exercise generation without beam search on randomly initialized generative models.

        Checks the following:
        - without a BOS token, generation requires ``input_ids``;
        - greedy decoding (``do_sample=False``) refuses ``num_return_sequences > 1``;
        - sampled ids lie within the vocabulary;
        - ``bad_words_ids`` never appear among the newly generated tokens.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # Move the prompt to the model's device, matching the beam-search variant of this test.
        input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
            torch_device
        )

        # make sure that input_ids is at most of size 15
        input_ids = input_ids[..., :15]

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            if config.bos_token_id is None:
                # if bos token id is not defined, model needs input_ids
                with self.assertRaises(AssertionError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(input_ids, do_sample=True))
            else:
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(do_sample=True, max_length=5))

            with self.assertRaises(AssertionError):
                # Returning multiple sequences without sampling or beam search is not allowed:
                # greedy decoding would always produce the same sequence.
                model.generate(input_ids, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [
                self._generate_random_bad_tokens(1, model.config),
                self._generate_random_bad_tokens(2, model.config),
            ]
            output_tokens = model.generate(
                input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
713

714
715
    def test_lm_head_model_random_beam_search_generate(self):
        """Exercise beam-search generation (``num_beams=2``) on randomly initialized generative models.

        Checks the following:
        - returned ids lie within the vocabulary;
        - requesting more sequences than beams fails;
        - ``bad_words_ids`` never appear among the newly generated tokens.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        id_key = "input_ids" if "input_ids" in inputs_dict else "inputs"
        # Keep at most the first 15 positions of the prompt.
        input_ids = inputs_dict[id_key].to(torch_device)[..., :15]

        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            if config.bos_token_id is not None:
                # A BOS token lets the model start without a prompt; num_return_sequences = 1.
                self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
            else:
                # Without a BOS token the model needs input_ids; num_return_sequences = 1.
                self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))

            # Asking for more returned sequences than there are beams must fail.
            with self.assertRaises(AssertionError):
                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

            # num_return_sequences > 1: first with sampling, then greedily.
            for sample in (True, False):
                self._check_generated_ids(
                    model.generate(input_ids, do_sample=sample, num_beams=2, num_return_sequences=2)
                )

            # Neither a 1-token nor a 2-token banned sequence may show up in the generated part.
            bad_words_ids = [self._generate_random_bad_tokens(length, model.config) for length in (1, 2)]
            output_tokens = model.generate(
                input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            prompt_length = input_ids.shape[-1]
            generated_ids = output_tokens[:, prompt_length:]
            self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))

756
    def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
        """Draw ``num_bad_tokens`` random token ids to form one banned word sequence.

        Ids come from ``ids_tensor`` over ``self.model_tester.vocab_size``, i.e. in ``[0, vocab_size - 1]``.
        Any id equal to the config's BOS, EOS or PAD token id is redrawn, so special tokens are never banned.
        Duplicate ids are possible. Note: if every vocabulary id were special, this would loop forever.

        Args:
            num_bad_tokens: length of the banned sequence to produce.
            config: model config providing ``bos_token_id``, ``eos_token_id`` and ``pad_token_id``.

        Returns:
            A list of ``num_bad_tokens`` plain Python ``int`` ids.
        """
        # special tokens cannot be bad tokens
        special_tokens = {x for x in (config.bos_token_id, config.eos_token_id, config.pad_token_id) if x is not None}

        # create random bad tokens that are not special tokens
        bad_tokens = []
        while len(bad_tokens) < num_bad_tokens:
            # ``.item()`` yields a Python int (not a numpy scalar), matching the ``List[int]`` annotation.
            token = ids_tensor((1, 1), self.model_tester.vocab_size).item()
            if token not in special_tokens:
                bad_tokens.append(token)
        return bad_tokens

767
    def _check_generated_ids(self, output_ids):
        """Assert that every id of every generated sequence lies in ``[0, self.model_tester.vocab_size)``.

        Args:
            output_ids: tensor of generated ids with one row per returned sequence. This is presumably
                ``(batch_size * num_return_sequences, seq_len)``; confirm against ``generate``.
        """
        # Validate all returned rows, not only the first one, so the extra sequences produced with
        # ``num_return_sequences > 1`` are checked too.
        for sequence in output_ids.tolist():
            for token_id in sequence:
                self.assertGreaterEqual(token_id, 0)
                self.assertLess(token_id, self.model_tester.vocab_size)

772
773
774
775
776
777
778
779
780
781
782
783
    def _check_match_tokens(self, generated_ids, bad_words_ids):
        # for all bad word tokens
        for bad_word_ids in bad_words_ids:
            # for all slices in batch
            for generated_ids_slice in generated_ids:
                # for all word idx
                for i in range(len(bad_word_ids), len(generated_ids_slice)):
                    # if tokens match
                    if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
                        return True
        return False

784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
    @require_multigpu
    def test_multigpu_data_parallel_forward(self):
        """A forward pass through ``torch.nn.DataParallel`` must succeed for every model class."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Some inputs must not be scattered across devices by nn.DataParallel, so drop them if present.
        for non_batched_key in ("head_mask",):
            inputs_dict.pop(non_batched_key, None)

        # Move every tensor input onto cuda:0.
        inputs_dict.update({key: value.to(0) for key, value in inputs_dict.items() if torch.is_tensor(value)})

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            parallel_model = torch.nn.DataParallel(model)
            with torch.no_grad():
                _ = parallel_model(**self._prepare_for_class(inputs_dict, model_class))
808

809

810
# Shared random generator used by `ids_tensor` and `floats_tensor` when no `rng` argument is passed.
# It is created without an explicit seed.
global_rng = random.Random()
thomwolf's avatar
thomwolf committed
811
812


thomwolf's avatar
thomwolf committed
813
def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Create a random ``torch.long`` tensor of ``shape`` with ids in ``[0, vocab_size - 1]``.

    Args:
        shape: sequence of dimension sizes for the result.
        vocab_size: exclusive upper bound of the sampled ids.
        rng: ``random.Random``-like generator; defaults to the module-level ``global_rng``.
        name: accepted for API compatibility but unused.

    Returns:
        A contiguous tensor placed on ``torch_device``.
    """
    generator = global_rng if rng is None else rng

    num_elements = 1
    for size in shape:
        num_elements *= size

    flat_ids = [generator.randint(0, vocab_size - 1) for _ in range(num_elements)]
    return torch.tensor(data=flat_ids, dtype=torch.long, device=torch_device).view(shape).contiguous()
thomwolf's avatar
thomwolf committed
827
828


829
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Create a random ``torch.float`` tensor of ``shape`` whose entries are ``rng.random() * scale``.

    Args:
        shape: sequence of dimension sizes for the result.
        scale: multiplier applied to every sampled value.
        rng: ``random.Random``-like generator; defaults to the module-level ``global_rng``.
        name: accepted for API compatibility but unused.

    Returns:
        A contiguous tensor placed on ``torch_device``.
    """
    generator = global_rng if rng is None else rng

    num_elements = 1
    for size in shape:
        num_elements *= size

    flat_values = [generator.random() * scale for _ in range(num_elements)]
    return torch.tensor(data=flat_values, dtype=torch.float, device=torch_device).view(shape).contiguous()
843
844


845
@require_torch
class ModelUtilsTest(unittest.TestCase):
    """Loading tests for a real pretrained BERT checkpoint."""

    @slow
    def test_model_from_pretrained(self):
        """Load the first BERT checkpoint and check its config, model and loading info."""
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            # A plain load must succeed. Its result is then replaced by a load that also returns loading info.
            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)
            # Every list reported in the loading info must be empty.
            for key_list in loading_info.values():
                self.assertEqual(len(key_list), 0)

            # Extra keyword arguments given to from_pretrained must be reflected in the model config.
            extra_kwargs = dict(output_attentions=True, output_hidden_states=True)
            config = BertConfig.from_pretrained(model_name, **extra_kwargs)
            model = BertModel.from_pretrained(model_name, **extra_kwargs)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)
865
866
867
868
869
870


@require_torch
class UtilsFunctionsTest(unittest.TestCase):

    # tests whether the top_k_top_p function behaves as expected
Patrick von Platen's avatar
Patrick von Platen committed
871
    def test_top_k_top_p_filtering(self):
        """Check ``top_k_top_p_filtering`` on two fixed rows of 30 logits.

        The call uses ``top_k=10``, ``top_p=0.6`` and ``min_tokens_to_keep=4``. With these settings exactly the
        five highest logits of each row must survive, at the positions and values listed below. Every other
        entry must be set to ``-inf``.
        """
        first_row = [
            8.2220991,  # 3rd highest value; idx. 0
            -0.5620044,
            5.23229752,
            4.0386393,
            -6.8798378,
            -0.54785802,
            -3.2012153,
            2.92777176,
            1.88171953,
            7.35341276,  # 5th highest value; idx. 9
            8.43207833,  # 2nd highest value; idx. 10
            -9.85711836,
            -5.96209236,
            -1.13039161,
            -7.1115294,
            -0.8369633,
            -5.3186408,
            7.06427407,
            0.81369344,
            -0.82023817,
            -5.9179796,
            0.58813443,
            -6.99778438,
            4.71551189,
            -0.18771637,
            7.44020759,  # 4th highest value; idx. 25
            9.38450987,  # 1st highest value; idx. 26
            2.12662941,
            -9.32562038,
            2.35652522,
        ]  # cumulative probability of the 5 highest values <= 0.6
        second_row = [
            0.58425518,
            4.53139238,
            -5.57510464,
            -6.28030699,
            -7.19529503,
            -4.02122551,
            1.39337037,
            -6.06707057,
            1.59480517,
            -9.643119,
            0.03907799,
            0.67231762,
            -8.88206726,
            6.27115922,  # 4th highest value; idx. 13
            2.28520723,
            4.82767506,
            4.30421368,
            8.8275313,  # 2nd highest value; idx. 17
            5.44029958,  # 5th highest value; idx. 18
            -4.4735794,
            7.38579536,  # 3rd highest value; idx. 20
            -2.91051663,
            2.61946077,
            -2.5674762,
            -9.48959302,
            -4.02922645,
            -1.35416918,
            9.67702323,  # 1st highest value; idx. 27
            -5.89478553,
            1.85370467,
        ]  # cumulative probability of the 5 highest values <= 0.6
        logits = torch.tensor([first_row, second_row], dtype=torch.float, device=torch_device)

        # (row, column) positions expected to survive filtering, as annotated above.
        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )

        # Surviving values, in the same row-major order as the positions above.
        non_inf_expected_output = torch.tensor(
            [8.2221, 7.3534, 8.4321, 7.4402, 9.3845, 6.2712, 8.8275, 5.4403, 7.3858, 9.6770],
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        kept_mask = output != -float("inf")
        non_inf_output = output[kept_mask].to(device=torch_device)
        non_inf_idx = kept_mask.nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))