# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os.path
import random
import tempfile
import unittest
from typing import List

from transformers import is_torch_available

from .utils import require_multigpu, require_torch, slow, torch_device


if is_torch_available():
    import torch
    import numpy as np

    from transformers import (
        AdaptiveEmbedding,
        PretrainedConfig,
        PreTrainedModel,
        BertModel,
        BertConfig,
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        top_k_top_p_filtering,
    )


def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key:
            setattr(configs_no_init, key, 1e-10)
    return configs_no_init


@require_torch
class ModelTesterMixin:
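    # Shared test mixin: concrete model test classes set `model_tester` and
    # `all_model_classes`; every test below iterates over those model classes.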

    model_tester = None
    all_model_classes = ()
    all_generative_model_classes = ()
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
    test_head_masking = True
    test_missing_keys = True
    is_encoder_decoder = False

    def _prepare_for_class(self, inputs_dict, model_class):
        if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
            return {
                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                for k, v in inputs_dict.items()
            }
        return inputs_dict

    def test_save_load(self):
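        # Round-trip each model through save_pretrained/from_pretrained and check
        # that the first output stays the same (up to a 1e-5 tolerance).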
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            out_2 = outputs[0].cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)
                model.to(torch_device)
                with torch.no_grad():
                    after_outputs = model(**self._prepare_for_class(inputs_dict, model_class))

                # Make sure we don't have nans
                out_1 = after_outputs[0].cpu().numpy()
                out_1[np.isnan(out_1)] = 0
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)

    def test_initialization(self):
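        # With all init ranges/stds forced to ~0, every trainable parameter should
        # be initialized to (approximately) 0.0 or 1.0.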
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
                    )

    def test_determinism(self):
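        # Two forward passes over the same inputs in eval mode must match.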
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
            out_1 = first.cpu().numpy()
            out_2 = second.cpu().numpy()
            out_1 = out_1[~np.isnan(out_1)]
            out_2 = out_2[~np.isnan(out_2)]
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

    def test_attention_outputs(self):
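        # Check the number and shape of the returned attention tensors, whether
        # attentions are requested through the inputs or through the config.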
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(model.config.output_hidden_states, False)
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**inputs_dict)
            attentions = outputs[-1]
            self.assertEqual(model.config.output_hidden_states, False)
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 4
                decoder_attention_idx = 1

                if "lm_labels" in inputs_dict:  # loss will come first
                    correct_outlen += 1  # compute loss
                    decoder_attention_idx += 1
                self.assertEqual(out_len, correct_outlen)

                decoder_attentions = outputs[decoder_attention_idx]
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
            self.assertEqual(model.config.output_hidden_states, True)

            self_attentions = outputs[-1]
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            if chunk_length is not None:
                self.assertListEqual(
                    list(self_attentions[0].shape[-4:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )

    def test_torchscript(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        self._create_and_check_torchscript(config, inputs_dict)

    def test_torchscript_output_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_attentions = True
        self._create_and_check_torchscript(config, inputs_dict)

    def test_torchscript_output_hidden_state(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        config.output_hidden_states = True
        self._create_and_check_torchscript(config, inputs_dict)

    def _create_and_check_torchscript(self, config, inputs_dict):
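        # Trace the model with torch.jit, save and reload the trace, and verify
        # that the reloaded module has an identical state dict.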
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"]  # Let's keep only input_ids

            try:
                traced_gpt2 = torch.jit.trace(model, inputs)
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_gpt2, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_headmasking(self):
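        # Mask selected attention heads via `head_mask` and check that the masked
        # heads produce zero attention while gradients can still be computed.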
        if not self.test_head_masking:
            return

        global_rng.seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        global_rng.seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            # Prepare head_mask
            # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
            head_mask = torch.ones(
                self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
            )
            head_mask[0, 0] = 0
            head_mask[-1, :-1] = 0
            head_mask.requires_grad_(requires_grad=True)
            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask

            outputs = model(**inputs)

            # Test that we can get a gradient back for importance score computation
            output = sum(t.sum() for t in outputs[0])
            output = output.sum()
            output.backward()
            multihead_outputs = head_mask.grad

            attentions = outputs[-1]

            # Remove Nan
            for t in attentions:
                self.assertLess(
                    torch.sum(torch.isnan(t)), t.numel() / 4
                )  # Check we don't have more than 25% nans (arbitrary)
            attentions = [
                t.masked_fill(torch.isnan(t), 0.0) for t in attentions
            ]  # remove them (the test is less complete)

            self.assertIsNotNone(multihead_outputs)
            self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)
            self.assertAlmostEqual(attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
            self.assertNotEqual(attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
            self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
            self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
            self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)

    def test_head_pruning(self):
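        # Prune heads with model.prune_heads() and check that the attention
        # tensors shrink along the head dimension accordingly.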
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_pretrained(self):
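        # Pruned heads must survive a save_pretrained/from_pretrained round trip.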
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False
            model = model_class(config=config)
            model.to(torch_device)
            model.eval()
            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            model.prune_heads(heads_to_prune)

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]
            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_save_load_from_config_init(self):
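        # Heads listed in config.pruned_heads must be pruned at construction time.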
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            heads_to_prune = {
                0: list(range(1, self.model_tester.num_attention_heads)),
                -1: [0],
            }
            config.pruned_heads = heads_to_prune

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)

    def test_head_pruning_integration(self):
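        # Combine pruning via the config, saving/reloading, and additional
        # prune_heads() calls on the reloaded model.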
        if not self.test_pruning:
            return

        for model_class in self.all_model_classes:
            (config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()

            if "head_mask" in inputs_dict:
                del inputs_dict["head_mask"]

            inputs_dict["output_attentions"] = True
            config.output_hidden_states = False

            heads_to_prune = {0: [0], 1: [1, 2]}
            config.pruned_heads = heads_to_prune

            model = model_class(config=config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            with tempfile.TemporaryDirectory() as temp_dir_name:
                model.save_pretrained(temp_dir_name)
                model = model_class.from_pretrained(temp_dir_name)
                model.to(torch_device)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            heads_to_prune = {0: [0], 2: [1, 2]}
            model.prune_heads(heads_to_prune)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs[-1]

            self.assertEqual(attentions[0].shape[-3], self.model_tester.num_attention_heads - 1)
            self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads - 2)
            self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})

    def test_hidden_states_output(self):
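        # With output_hidden_states enabled, expect num_hidden_layers + 1 hidden
        # states of shape (..., seq_length, hidden_size).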
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            config.output_hidden_states = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs[-1]
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
                if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
                    seq_length = seq_length * self.model_tester.chunk_length
            else:
                seq_length = self.model_tester.seq_length

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
            )

    def test_resize_tokens_embeddings(self):
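        # Grow and shrink the token embeddings and check config.vocab_size, the
        # embedding matrix shape, and that a forward pass still succeeds.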
        (original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Input ids should be clamped to the maximum size of the vocabulary
            inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_model_common_attributes(self):
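        # Every model must expose settable input embeddings and, if present,
        # output embeddings that are an nn.Linear.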
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
            model.set_input_embeddings(torch.nn.Embedding(10, 10))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, torch.nn.Linear))

    def test_correct_missing_keys(self):
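        # Loading a base-model checkpoint into a model with a head should report
        # the head weights as missing keys.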
        if not self.test_missing_keys:
            return
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            base_model_prefix = model.base_model_prefix

            if hasattr(model, base_model_prefix):
                with tempfile.TemporaryDirectory() as temp_dir_name:
                    model.base_model.save_pretrained(temp_dir_name)
                    model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)

                    with self.subTest(msg="Missing keys for {}".format(model.__class__.__name__)):
                        self.assertGreater(len(loading_info["missing_keys"]), 0)

    def test_tie_model_weights(self):
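        # Compare parameter counts of a weight-tied model against its untied
        # (torchscript) counterpart, including after resizing the embeddings.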
        if not self.test_torchscript:
            return

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_same_values(layer_1, layer_2):
            equal = True
            for p1, p2 in zip(layer_1.weight, layer_2.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    equal = False
            return equal

        for model_class in self.all_model_classes:
            config.torchscript = True
            model_not_tied = model_class(config)
            if model_not_tied.get_output_embeddings() is None:
                continue

            params_not_tied = list(model_not_tied.parameters())

            config_tied = copy.deepcopy(config)
            config_tied.torchscript = False
            model_tied = model_class(config_tied)
            params_tied = list(model_tied.parameters())

            # Check that the embedding layer and decoding layer are the same in size and in value
            self.assertGreater(len(params_not_tied), len(params_tied))
            # self.assertTrue(check_same_values(embeddings, decoding))

            # # Check that after modification, they remain the same.
            # embeddings.weight.data.div_(2)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
            # self.assertTrue(check_same_values(embeddings, decoding))

            # # Check that after modification, they remain the same.
            # decoding.weight.data.div_(4)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(embeddings.weight.shape, decoding.weight.shape)
            # self.assertTrue(check_same_values(embeddings, decoding))

            # Check that after resize they remain tied.
            model_tied.resize_token_embeddings(config.vocab_size + 10)
            params_tied_2 = list(model_tied.parameters())
            self.assertGreater(len(params_not_tied), len(params_tied))
            self.assertEqual(len(params_tied_2), len(params_tied))

            # decoding.weight.data.mul_(20)
            # # Check that the embedding layer and decoding layer are the same in size and in value
            # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
            # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

    def test_inputs_embeds(self):
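        # Models must accept precomputed inputs_embeds (and decoder_inputs_embeds
        # for encoder-decoder models) in place of input_ids.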

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.is_encoder_decoder:
            input_ids = inputs_dict["input_ids"]
            del inputs_dict["input_ids"]
        else:
            encoder_input_ids = inputs_dict["input_ids"]
            decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
            del inputs_dict["input_ids"]
            inputs_dict.pop("decoder_input_ids", None)

        for model_class in self.all_model_classes:
            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                continue
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs_dict["inputs_embeds"] = wte(input_ids)
            else:
                inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
                inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs_dict)

    def test_lm_head_model_random_no_beam_search_generate(self):
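        # Exercise generate() without beam search: sampling, invalid argument
        # combinations, and the bad_words_ids filter.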
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

        # make sure that input_ids is at most of size 15
        input_ids = input_ids[..., :15]

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            if config.bos_token_id is None:
                # if bos token id is not defined, model needs input_ids
                with self.assertRaises(AssertionError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(input_ids, do_sample=True))
            else:
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(do_sample=True, max_length=5))

            with self.assertRaises(AssertionError):
                # generating multiple sequences without beam search is not allowed,
                # as it would always generate the same sequences
                model.generate(input_ids, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))

            # check language generation with banned (bad words) tokens
            # create a list with a 1-token bad word and a 2-token bad word
            bad_words_ids = [
                self._generate_random_bad_tokens(1, model.config),
                self._generate_random_bad_tokens(2, model.config),
            ]
            output_tokens = model.generate(
                input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))

    def test_lm_head_model_random_beam_search_generate(self):
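        # Same generate() checks as above, but with beam search (num_beams=2).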
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
            torch_device
        )

        # make sure that input_ids is at most of size 15
        input_ids = input_ids[..., :15]

        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            if config.bos_token_id is None:
                # if bos token id is not defined, model needs input_ids, num_return_sequences = 1
                self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
            else:
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))

            with self.assertRaises(AssertionError):
                # generating more return sequences than we have beams is not possible
                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2,))
            # num_return_sequences > 1, greedy
            self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))

            # check language generation with banned (bad words) tokens
            # create a list with a 1-token bad word and a 2-token bad word
            bad_words_ids = [
                self._generate_random_bad_tokens(1, model.config),
                self._generate_random_bad_tokens(2, model.config),
            ]
            output_tokens = model.generate(
                input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))

    def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
        # special tokens cannot be bad tokens
        special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
        # create random bad tokens that are not special tokens
        bad_tokens = []
        while len(bad_tokens) < num_bad_tokens:
            token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
            if token not in special_tokens:
                bad_tokens.append(token)
        return bad_tokens

    def _check_generated_ids(self, output_ids):
        for token_id in output_ids[0].tolist():
            self.assertGreaterEqual(token_id, 0)
            self.assertLess(token_id, self.model_tester.vocab_size)

    def _check_match_tokens(self, generated_ids, bad_words_ids):
        # for all bad word tokens
        for bad_word_ids in bad_words_ids:
            # for all slices in batch
            for generated_ids_slice in generated_ids:
                # for all word idx
                for i in range(len(bad_word_ids), len(generated_ids_slice)):
                    # if tokens match
                    if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
                        return True
        return False

    @require_multigpu
    def test_multigpu_data_parallel_forward(self):
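        # The forward pass must also work when the model is wrapped in
        # torch.nn.DataParallel across multiple GPUs.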
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # some params shouldn't be scattered by nn.DataParallel
        # so just remove them if they are present.
        blacklist_non_batched_params = ["head_mask"]
        for k in blacklist_non_batched_params:
            inputs_dict.pop(k, None)

        # move input tensors to cuda:0
        for k, v in inputs_dict.items():
            if torch.is_tensor(v):
                inputs_dict[k] = v.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = torch.nn.DataParallel(model)
            with torch.no_grad():
                _ = model(**inputs_dict)


global_rng = random.Random()


def ids_tensor(shape, vocab_size, rng=None, name=None):
    # Creates a random integer tensor of the given shape with values in [0, vocab_size)
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()


@require_torch
class ModelUtilsTest(unittest.TestCase):
    @slow
    def test_model_from_pretrained(self):
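        # Download a real BERT checkpoint and sanity-check config/model loading,
        # including output_loading_info and config overrides.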
        logging.basicConfig(level=logging.INFO)
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
            self.assertEqual(model.config.output_hidden_states, True)
            self.assertEqual(model.config, config)


@require_torch
class UtilsFunctionsTest(unittest.TestCase):

    # tests whether the top_k_top_p function behaves as expected
    def test_top_k_top_p_filtering(self):
        logits = torch.tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=torch.float,
            device=torch_device,
        )

        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = torch.tensor(
            [
                8.2221,
                7.3534,
                8.4321,
                7.4402,
                9.3845,
                6.2712,
                8.8275,
                5.4403,
                7.3858,
                9.6770,
            ],  # expected non filtered values as noted above
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        non_inf_output = output[output != -float("inf")].to(device=torch_device)
        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))