test_modeling_tf_common.py 42.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

thomwolf's avatar
thomwolf committed
16
17

import copy
18
import inspect
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import os
thomwolf's avatar
thomwolf committed
20
import random
Aymeric Augustin's avatar
Aymeric Augustin committed
21
import tempfile
22
import unittest
23
from importlib import import_module
24
from typing import List, Tuple
thomwolf's avatar
thomwolf committed
25

26
from transformers import is_tf_available, is_torch_available
Julien Plu's avatar
Julien Plu committed
27
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
28

Aymeric Augustin's avatar
Aymeric Augustin committed
29

30
if is_tf_available():
thomwolf's avatar
thomwolf committed
31
    import numpy as np
32
    import tensorflow as tf
33

34
    from transformers import (
35
36
        TF_MODEL_FOR_CAUSAL_LM_MAPPING,
        TF_MODEL_FOR_MASKED_LM_MAPPING,
37
        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
38
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
39
        TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
40
41
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
42
43
44
        TFAdaptiveEmbedding,
        TFSharedEmbeddings,
        tf_top_k_top_p_filtering,
45
    )
46

Julien Chaumond's avatar
Julien Chaumond committed
47
48
49
50
51
52
53
54
55
56
57
58
59
    if _tf_gpu_memory_limit is not None:
        # Best-effort cap on the memory TensorFlow may allocate on each visible GPU.
        # Configuration failures (RuntimeError) are printed and otherwise ignored.
        for physical_gpu in tf.config.list_physical_devices("GPU"):
            try:
                capped_device = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)
                tf.config.experimental.set_virtual_device_configuration(physical_gpu, [capped_device])
                print("Logical GPUs", tf.config.experimental.list_logical_devices("GPU"))
            except RuntimeError as err:
                # Virtual devices must be set before GPUs have been initialized
                print(err)
thomwolf's avatar
thomwolf committed
60

61

thomwolf's avatar
thomwolf committed
62
63
64
def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
65
        if "_range" in key or "_std" in key:
thomwolf's avatar
thomwolf committed
66
67
68
69
            setattr(configs_no_init, key, 0.0)
    return configs_no_init


70
71
@require_tf
class TFModelTesterMixin:
72

73
74
    model_tester = None
    all_model_classes = ()
75
    all_generative_model_classes = ()
76
77
    test_resize_embeddings = True
    is_encoder_decoder = False
78

79
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Return a deep copy of ``inputs_dict`` adapted to ``model_class``.

        For multiple-choice models, every ``tf.Tensor`` with at least one dimension
        is tiled along a new axis 1 of size ``self.model_tester.num_choices``.

        When ``return_labels`` is true, dummy labels are added according to the head:

        - all-ones ``labels`` of shape ``(batch_size,)`` for multiple choice;
        - zero ``start_positions``/``end_positions`` for question answering;
        - zero ``labels`` of shape ``(batch_size,)`` for sequence classification;
        - zero ``labels`` of shape ``(batch_size, seq_length)`` for token
          classification and causal / masked / seq2seq language modeling.
        """
        prepared = copy.deepcopy(inputs_dict)
        is_multiple_choice = model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values()

        if is_multiple_choice:
            tiled = {}
            for name, value in prepared.items():
                if isinstance(value, tf.Tensor) and value.ndim > 0:
                    multiples = (1, self.model_tester.num_choices) + (1,) * (value.ndim - 1)
                    value = tf.tile(tf.expand_dims(value, 1), multiples)
                tiled[name] = value
            prepared = tiled

        if not return_labels:
            return prepared

        token_level_mappings = (
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
            TF_MODEL_FOR_CAUSAL_LM_MAPPING,
            TF_MODEL_FOR_MASKED_LM_MAPPING,
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        )
        if is_multiple_choice:
            prepared["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
        elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
            prepared["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
            prepared["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
            prepared["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif any(model_class in mapping.values() for mapping in token_level_mappings):
            prepared["labels"] = tf.zeros(
                (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
            )
        return prepared

109
110
    def test_initialization(self):
        """Intentionally empty: this mixin runs no weight-initialization checks for TF models."""
111

112
113
    def test_save_load(self):
        """Check a ``save_pretrained``/``from_pretrained`` round trip reproduces the outputs."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            original_model = model_class(config)
            expected = original_model(self._prepare_for_class(inputs_dict, model_class))

            with tempfile.TemporaryDirectory() as tmp_dir:
                original_model.save_pretrained(tmp_dir)
                reloaded_model = model_class.from_pretrained(tmp_dir)
                actual = reloaded_model(self._prepare_for_class(inputs_dict, model_class))
                self.assert_outputs_same(actual, expected)
125

126
127
128
129
130
131
132
133
134
135
136
137
138
    def test_graph_mode(self):
        """Check every model can be called inside a ``tf.function`` and returns something."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model_inputs = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)

            @tf.function
            def call_model_in_graph():
                return model(model_inputs)

            self.assertIsNotNone(call_model_in_graph())

Julien Plu's avatar
Julien Plu committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    @slow
    def test_saved_model_with_hidden_states_output(self):
        """Check hidden states survive a ``tf.saved_model`` export and Keras reload.

        Each model is called once to record how many outputs it returns. It is then
        exported with ``tf.saved_model.save``, reloaded with
        ``tf.keras.models.load_model`` and called again. The reloaded model must
        return the same number of outputs. Its last output must hold
        ``num_hidden_layers + 1`` tensors whose trailing dimensions are
        ``[seq_length, hidden_size]``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True

        for model_class in self.all_model_classes:
            # Keep ``inputs_dict`` untouched. Rebinding it here would hand inputs already
            # prepared for one class (e.g. tiled for multiple choice) to the next class.
            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)
            num_out = len(model(class_inputs_dict))
            # Reset the private Keras save spec and set it from these inputs before exporting.
            model._saved_model_inputs_spec = None
            model._set_save_spec(class_inputs_dict)

            with tempfile.TemporaryDirectory() as tmpdirname:
                tf.saved_model.save(model, tmpdirname)
                model = tf.keras.models.load_model(tmpdirname)
                outputs = model(class_inputs_dict)
                # The reloaded model may return a dict or a sequence. Either way the test
                # expects the hidden states to be the last entry.
                output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
                hidden_states = [t.numpy() for t in output]
                self.assertEqual(len(outputs), num_out)
                self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [self.model_tester.seq_length, self.model_tester.hidden_size],
                )

    @slow
    def test_saved_model_with_attentions_output(self):
        """Check attentions survive a ``tf.saved_model`` export and Keras reload.

        Each model is called once to record how many outputs it returns. It is then
        exported with ``tf.saved_model.save``, reloaded with
        ``tf.keras.models.load_model`` and called again. The reloaded model must
        return the same number of outputs. Its last output must hold one attention
        tensor per layer whose last three dimensions are
        ``[num_attention_heads, encoder_seq_length, encoder_key_length]``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_attentions = True
        encoder_seq_length = (
            self.model_tester.encoder_seq_length
            if hasattr(self.model_tester, "encoder_seq_length")
            else self.model_tester.seq_length
        )
        encoder_key_length = (
            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
        )

        for model_class in self.all_model_classes:
            # Keep ``inputs_dict`` untouched. Rebinding it here would hand inputs already
            # prepared for one class (e.g. tiled for multiple choice) to the next class.
            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)
            num_out = len(model(class_inputs_dict))
            # Reset the private Keras save spec and set it from these inputs before exporting.
            model._saved_model_inputs_spec = None
            model._set_save_spec(class_inputs_dict)

            with tempfile.TemporaryDirectory() as tmpdirname:
                tf.saved_model.save(model, tmpdirname)
                model = tf.keras.models.load_model(tmpdirname)
                outputs = model(class_inputs_dict)
                # The reloaded model may return a dict or a sequence. Either way the test
                # expects the attentions to be the last entry.
                output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
                attentions = [t.numpy() for t in output]
                self.assertEqual(len(outputs), num_out)
                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                )

197
198
199
200
201
202
203
204
    def test_keras_save_load(self):
        """Check Keras-serializable ``*MainLayer`` classes survive an HDF5 save/load.

        Main layers are collected from the modules defining ``all_model_classes``.
        Eligible classes have a name ending in ``"MainLayer"``, list
        ``tf.keras.layers.Layer`` among their direct bases, and have a truthy
        ``_keras_serializable``. Each one is wrapped in a functional Keras model,
        saved to ``keras_model.h5`` and reloaded with the required
        ``custom_objects``. The reloaded outputs must match the originals.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        tf_main_layer_classes = set()
        for model_class in self.all_model_classes:
            module = import_module(model_class.__module__)
            for member_name in dir(module):
                if not member_name.endswith("MainLayer"):
                    continue
                member = getattr(module, member_name)
                if (
                    isinstance(member, type)
                    and tf.keras.layers.Layer in member.__bases__
                    and getattr(member, "_keras_serializable", False)
                ):
                    tf_main_layer_classes.add(member)

        for main_layer_class in tf_main_layer_classes:
            is_t5 = "T5" in main_layer_class.__name__
            if is_t5:
                # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter.
                # Use the same values as TFT5ModelTester for this shared layer.
                shared = TFSharedEmbeddings(99, 32, name="shared")
                config.use_cache = False
                main_layer = main_layer_class(config, embed_tokens=shared)
            else:
                main_layer = main_layer_class(config)

            symbolic_inputs = {
                name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
            }
            model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
            outputs = model(inputs_dict)

            custom_objects = {main_layer_class.__name__: main_layer_class}
            if is_t5:
                custom_objects["TFSharedEmbeddings"] = TFSharedEmbeddings

            with tempfile.TemporaryDirectory() as tmpdirname:
                filepath = os.path.join(tmpdirname, "keras_model.h5")
                model.save(filepath)
                model = tf.keras.models.load_model(filepath, custom_objects=custom_objects)
                assert isinstance(model, tf.keras.Model)
                after_outputs = model(inputs_dict)
                self.assert_outputs_same(after_outputs, outputs)

    def assert_outputs_same(self, after_outputs, outputs):
        """Assert that the first output of two model calls matches within 1e-5.

        Each argument may be a single ``tf.Tensor`` (used whole), a dict, including
        dict-like model outputs (its first value is used), or a sequence (its first
        element is used). Both are converted to NumPy before comparison. Positions
        that are NaN in either output are excluded, so the remaining values stay
        aligned element by element.
        """

        def first_output_as_numpy(model_outputs):
            # Both arguments are normalized the same way; previously the dict branch
            # returned a tf.Tensor that could not be masked with a NumPy boolean array.
            if isinstance(model_outputs, tf.Tensor):
                first = model_outputs
            elif isinstance(model_outputs, dict):
                first = model_outputs[list(model_outputs.keys())[0]]
            else:
                first = model_outputs[0]
            return np.asarray(first)

        out_1 = first_output_as_numpy(after_outputs)
        out_2 = first_output_as_numpy(outputs)
        self.assertEqual(out_1.shape, out_2.shape)

        # Ignore positions that are NaN in either output, using one joint mask so the
        # compared elements correspond.
        comparable = ~(np.isnan(out_1) | np.isnan(out_2))
        max_diff = np.amax(np.abs(out_1[comparable] - out_2[comparable]))
        self.assertLessEqual(max_diff, 1e-5)
260

261
262
263
    def test_pt_tf_model_equivalence(self):
        """Check the TF and PyTorch versions of each model agree after weight transfer.

        Weights are moved PT -> TF -> PT twice: first with the in-memory
        ``load_*_model`` helpers, then through checkpoints written to a temporary
        directory. After each transfer, the first outputs of both models must agree
        within 4e-2, ignoring positions that are NaN in either output. The test
        returns early (without failing) when PyTorch is not available.
        """
        if not is_torch_available():
            return

        import torch

        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def prepare_pt_inputs(model_class):
            # Build PyTorch inputs from the prepared TF inputs; every input is cast to torch.long.
            pt_inputs_dict = {
                name: torch.from_numpy(value.numpy()).to(torch.long)
                for name, value in self._prepare_for_class(inputs_dict, model_class).items()
            }
            # need to rename encoder-decoder "inputs" for PyTorch
            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
            return pt_inputs_dict

        def assert_first_outputs_close(tf_model, pt_model, model_class):
            # Check predictions on the first output (logits/hidden states) are close enough,
            # given low-level computational differences. Both models run in inference mode.
            pt_model.eval()
            pt_inputs_dict = prepare_pt_inputs(model_class)
            with torch.no_grad():
                pt_outputs = pt_model(**pt_inputs_dict)
            tf_outputs = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)

            tf_first = tf_outputs[0].numpy()
            pt_first = pt_outputs[0].numpy()

            # Zero every position that is NaN in either output before comparing.
            nan_positions = np.isnan(tf_first) | np.isnan(pt_first)
            tf_first[nan_positions] = 0
            pt_first[nan_positions] = 0

            max_diff = np.amax(np.abs(tf_first - pt_first))
            self.assertLessEqual(
                max_diff,
                4e-2,
                msg=f"{model_class.__name__}: max abs difference between TF and PT outputs is {max_diff}",
            )

        for model_class in self.all_model_classes:
            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)

            config.output_hidden_states = True

            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(
                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
            )
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
            assert_first_outputs_close(tf_model, pt_model, model_class)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            assert_first_outputs_close(tf_model, pt_model, model_class)
356
357
358
359
360
361
362
363
364

    def test_compile_tf_model(self):
        """Check each reloaded model can be extended with a Dense head and compiled in Keras.

        The model is built with real inputs, saved and reloaded via
        ``save_pretrained``/``from_pretrained``, then called on symbolic
        ``tf.keras.Input``s with a fixed batch shape. A softmax ``Dense`` layer is
        added on its first output, and the resulting ``tf.keras.Model`` is compiled.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

        def int32_input(batch_shape, name="input_ids"):
            return tf.keras.Input(batch_shape=batch_shape, name=name, dtype="int32")

        for model_class in self.all_model_classes:
            if self.is_encoder_decoder:
                input_ids = {
                    "decoder_input_ids": int32_input((2, 2000), name="decoder_input_ids"),
                    "input_ids": int32_input((2, 2000)),
                }
            elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                input_ids = int32_input((4, 2, 2000))
            else:
                input_ids = int32_input((2, 2000))

            # Build the model, then round-trip it through disk so pretrained-weight loading is exercised.
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model(self._prepare_for_class(inputs_dict, model_class))
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)

            hidden_states = model(input_ids)[0]

            # Add a dense layer on top to test integration with other keras modules
            head_outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)

            extended_model = tf.keras.Model(inputs=[input_ids], outputs=[head_outputs])
            extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

    def test_keyword_and_dict_args(self):
        """Check that passing one inputs dict and passing ``input_ids`` plus keywords agree on the first output."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            from_dict = model(self._prepare_for_class(inputs_dict, model_class))[0].numpy()

            keyword_inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            input_ids = keyword_inputs.pop("input_ids", None)
            from_keywords = model(input_ids, **keyword_inputs)[0].numpy()

            self.assertLess(np.sum(np.abs(from_dict - from_keywords)), 1e-6)

    def test_attention_outputs(self):
        """Check attention outputs requested via a call kwarg and via the config.

        For every model class, on fresh copies of the config and inputs:

        1. ``output_attentions=True`` passed as an input (hidden states off) must
           yield one attention tensor per layer. The last three dimensions must be
           ``[num_attention_heads, encoder_seq_length, encoder_key_length]``. For
           encoder-decoder models, the output count must be even, and the entry at
           ``out_len // 2 - 1`` must hold decoder attentions of the matching shape.
        2. Setting only ``config.output_attentions = True`` must give the same attentions.
        3. Enabling hidden states as well must add exactly one output (two for
           encoder-decoder models), with attentions still the last output.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        decoder_seq_length = (
            self.model_tester.decoder_seq_length
            if hasattr(self.model_tester, "decoder_seq_length")
            else self.model_tester.seq_length
        )
        encoder_seq_length = (
            self.model_tester.encoder_seq_length
            if hasattr(self.model_tester, "encoder_seq_length")
            else self.model_tester.seq_length
        )
        decoder_key_length = (
            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else decoder_seq_length
        )
        encoder_key_length = (
            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
        )

        def check_encoder_attentions(outputs):
            # Attentions are expected to be the last output entry.
            attentions = [t.numpy() for t in outputs[-1]]
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )

        for model_class in self.all_model_classes:
            # Fresh copies per class: the steps below toggle config flags and input keys.
            # With shared objects, ``config.output_attentions = True`` from step 2 would leak
            # into the next class and make its kwarg-only check (step 1) vacuous.
            class_config = copy.deepcopy(config)
            class_inputs_dict = dict(inputs_dict)

            # 1. Request attentions through the inputs.
            class_inputs_dict["output_attentions"] = True
            class_config.output_hidden_states = False
            model = model_class(class_config)
            outputs = model(self._prepare_for_class(class_inputs_dict, model_class))
            self.assertEqual(model.config.output_hidden_states, False)
            check_encoder_attentions(outputs)
            out_len = len(outputs)

            if self.is_encoder_decoder:
                self.assertEqual(out_len % 2, 0)
                decoder_attentions = outputs[(out_len // 2) - 1]
                self.assertEqual(model.config.output_hidden_states, False)
                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

            # 2. Check that output attentions can also be changed via the config
            del class_inputs_dict["output_attentions"]
            class_config.output_attentions = True
            model = model_class(class_config)
            outputs = model(self._prepare_for_class(class_inputs_dict, model_class))
            self.assertEqual(model.config.output_hidden_states, False)
            check_encoder_attentions(outputs)

            # 3. Check attention is always last and order is fine
            class_inputs_dict["output_attentions"] = True
            class_config.output_hidden_states = True
            model = model_class(class_config)
            outputs = model(self._prepare_for_class(class_inputs_dict, model_class))
            self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
            self.assertEqual(model.config.output_hidden_states, True)
            check_encoder_attentions(outputs)
482

483
484
485
    def test_hidden_states_output(self):
        """Check hidden states can be requested via a call kwarg and via the config.

        For every model class, first ``output_hidden_states=True`` is passed as an
        input, then only ``config.output_hidden_states = True`` is set. Both times
        the last output must hold ``num_hidden_layers + 1`` tensors whose trailing
        dimensions are ``[seq_length, hidden_size]``.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_hidden_states_output(config, inputs_dict, model_class):
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))
            hidden_states = [t.numpy() for t in outputs[-1]]
            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [self.model_tester.seq_length, self.model_tester.hidden_size],
            )

        for model_class in self.all_model_classes:
            # Fresh copies per class so ``config.output_hidden_states = True`` set below does
            # not leak into the next class and make its kwarg-only check vacuous.
            class_config = copy.deepcopy(config)
            class_inputs_dict = dict(inputs_dict)

            class_inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(class_config, class_inputs_dict, model_class)

            del class_inputs_dict["output_hidden_states"]
            class_config.output_hidden_states = True
            check_hidden_states_output(class_config, class_inputs_dict, model_class)

504
505
506
507
508
    def test_model_common_attributes(self):
        """Check the embedding accessors return Keras layers (output embeddings may be ``None``)."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            input_embeddings = model.get_input_embeddings()
            assert isinstance(input_embeddings, (tf.keras.layers.Layer, TFAdaptiveEmbedding))

            output_embeddings = model.get_output_embeddings()
            assert output_embeddings is None or isinstance(output_embeddings, tf.keras.layers.Layer)

    def test_determinism(self):
        """Check two ``training=False`` calls with identical inputs give the same first output (within 1e-5)."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            runs = [
                model(self._prepare_for_class(inputs_dict, model_class), training=False)[0].numpy()
                for _ in range(2)
            ]
            # NaN entries are dropped from each run independently before comparing.
            first_run, second_run = (run[~np.isnan(run)] for run in runs)
            max_diff = np.amax(np.abs(first_run - second_run))
            self.assertLessEqual(max_diff, 1e-5)

529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
    def test_model_outputs_equivalence(self):
        """Check ``return_dict=False`` tuples match ``return_dict=True`` outputs.

        For each model class, the same inputs are run with and without labels, and
        with each combination of extra output flags listed in ``cases``. The tuple
        output must equal ``to_tuple()`` of the dict output element-wise, recursing
        into nested lists/tuples and skipping ``None`` entries. NaNs in matching
        positions count as equal.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            # ``None`` default avoids sharing one mutable dict across calls.
            additional_kwargs = {} if additional_kwargs is None else additional_kwargs
            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (list, tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    # Element-wise equality for tensors of any rank (including scalars).
                    np.testing.assert_array_equal(
                        np.asarray(tuple_object),
                        np.asarray(dict_object),
                        err_msg="Tuple and dict output are not equal.",
                    )

            # Run the comparison from check_equivalence itself (not from inside recursive_check).
            recursive_check(tuple_output, dict_output)

        # (return_labels, extra call kwargs) pairs, checked in this order for every model class.
        cases = (
            (False, {}),
            (True, {}),
            (False, {"output_hidden_states": True}),
            (False, {"output_attentions": True}),
            (True, {"output_hidden_states": True}),
            (True, {"output_attentions": True}),
            (True, {"output_hidden_states": True, "output_attentions": True}),
        )

        for model_class in self.all_model_classes:
            model = model_class(config)
            for return_labels, extra_kwargs in cases:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
                check_equivalence(model, tuple_inputs, dict_inputs, extra_kwargs)

584
585
586
587
588
589
590
    def _get_embeds(self, wte, input_ids):
        """Return input embeddings for ``input_ids`` computed with the layer ``wte``.

        TF models accept the embedding call in slightly different forms, so
        ``wte(..., mode="embedding")`` is tried with ``input_ids``, then
        ``[input_ids]``, then ``[input_ids, None, None, None]``; the first call
        that does not raise wins. If all of them raise, a float32 tensor of ones
        with shape ``input_ids.shape + [width]`` is returned instead. ``width`` is
        ``self.model_tester.embedding_size`` when the tester defines it, otherwise
        ``self.model_tester.hidden_size``.
        """
        for candidate in (input_ids, [input_ids], [input_ids, None, None, None]):
            try:
                return wte(candidate, mode="embedding")
            except Exception:
                continue

        if hasattr(self.model_tester, "embedding_size"):
            embedding_width = self.model_tester.embedding_size
        else:
            embedding_width = self.model_tester.hidden_size
        return tf.ones(
            input_ids.shape + [embedding_width],
            dtype=tf.dtypes.float32,
        )

    def test_inputs_embeds(self):
        """Check that every model runs when given ``inputs_embeds`` instead of ``input_ids``.

        For encoder-decoder models the decoder also receives ``decoder_inputs_embeds``. These are
        built from ``decoder_input_ids`` when present and from the encoder ids otherwise.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if self.is_encoder_decoder:
                encoder_input_ids = inputs.pop("input_ids")
                # fall back to the encoder ids when no decoder ids were prepared
                decoder_input_ids = inputs.pop("decoder_input_ids", encoder_input_ids)
                wte = model.get_input_embeddings()
                inputs["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
                inputs["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
            else:
                input_ids = inputs.pop("input_ids")
                wte = model.get_input_embeddings()
                inputs["inputs_embeds"] = self._get_embeds(wte, input_ids)

            model(inputs)
633

634
635
636
637
638
639
640
641
642
643
644
645
646
    def test_resize_token_embeddings(self):
        """Check ``_get_resized_embeddings`` when shrinking, growing, and passing ``None``.

        The resized matrix must have the requested number of rows (``config.vocab_size`` when the
        requested size is ``None``). The rows it shares with the old matrix must be unchanged.
        """
        if not self.test_resize_embeddings:
            return
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        input_shape = [1, 10, config.hidden_size]
        for model_class in self.all_model_classes:
            for size in (config.vocab_size - 10, config.vocab_size + 10, None):
                # build the embeddings, then reshape them
                model = model_class(config=config)
                emb_old = model.get_input_embeddings()
                emb_old.build(input_shape)
                new_embeddings = model._get_resized_embeddings(emb_old, size)

                # the resized embeddings must have the desired number of rows
                expected_rows = config.vocab_size if size is None else size
                self.assertEqual(new_embeddings.shape[0], expected_rows)

                # weights must stay the same on the shared rows (zip stops at the shorter matrix);
                # `not (... > 0)` deliberately mirrors the original comparison, including for NaN
                old_weights = model._get_word_embeddings(emb_old)
                models_equal = not any(
                    np.sum(abs(old_row - new_row)) > 0
                    for old_row, new_row in zip(old_weights.numpy(), new_embeddings.numpy())
                )
                self.assertTrue(models_equal)

658
    def test_lm_head_model_random_no_beam_search_generate(self):
        """Exercise ``generate`` without beam search on every generative model.

        Covers:
        - generation with or without a prompt, depending on ``bos_token_id``;
        - rejection of ``num_return_sequences > 1`` for greedy decoding;
        - sampling several sequences;
        - absence of ``bad_words_ids`` in the generated continuation.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        try:
            input_ids = inputs_dict["input_ids"]
        except KeyError:
            input_ids = inputs_dict["inputs"]

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is not None:
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(do_sample=True, max_length=5))
            else:
                # if bos token id is not defined the model needs input_ids
                with self.assertRaises(AssertionError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(input_ids, do_sample=True))

            with self.assertRaises(AssertionError):
                # generating multiple sequences without beam search is not allowed
                # when not sampling, as it would always generate the same sequences
                model.generate(input_ids, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))

            # check bad words tokens language generation:
            # one single-token bad sequence and one two-token bad sequence
            bad_words_ids = [self._generate_random_bad_tokens(length, model) for length in (1, 2)]
            output_tokens = model.generate(
                input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )
            # only count generated tokens, not the prompt
            continuation = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(continuation.numpy().tolist(), bad_words_ids))
693

694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
    def test_lm_head_model_random_beam_search_generate(self):
        """Exercise beam-search ``generate`` (``num_beams=2``) on every generative model.

        Covers:
        - generation with or without a prompt, depending on ``bos_token_id``;
        - rejection of more return sequences than beams;
        - sampled and greedy beam search returning several sequences;
        - absence of ``bad_words_ids`` in the generated continuation.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        try:
            input_ids = inputs_dict["input_ids"]
        except KeyError:
            input_ids = inputs_dict["inputs"]

        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is not None:
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
            else:
                # if bos token id is not defined the model needs input_ids, num_return_sequences = 1
                self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))

            with self.assertRaises(AssertionError):
                # generating more sequences than there are beams is not possible
                model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

            # num_return_sequences > 1: first with sampling, then greedy
            for do_sample in (True, False):
                self._check_generated_ids(
                    model.generate(input_ids, do_sample=do_sample, num_beams=2, num_return_sequences=2)
                )

            # check bad words tokens language generation:
            # one single-token bad sequence and one two-token bad sequence
            bad_words_ids = [self._generate_random_bad_tokens(length, model) for length in (1, 2)]
            output_tokens = model.generate(
                input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            # only count generated tokens, not the prompt
            continuation = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(continuation.numpy().tolist(), bad_words_ids))

734
735
736
737
738
739
740
741
742
743
    def test_loss_computation(self):
        """Check the loss size for every model that defines ``compute_loss``.

        The loss is computed from keyword arguments, from a single dict, and from a positional
        tuple. Each time it must have one element per label element. For causal-LM models, whose
        labels are shifted, ``batch_size`` elements fewer are expected.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            if not getattr(model, "compute_loss", None):
                continue

            # The number of elements in the loss should be the same as the number of elements in the label
            labelled_inputs = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            first_label_key = list(labelled_inputs.keys() - inputs_dict.keys())[0]
            loss_size = tf.size(labelled_inputs[first_label_key])
            if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
                # causal LM labels are shifted, so one label per batch row is cut
                loss_size = loss_size - self.model_tester.batch_size

            # 1) keyword arguments
            kwargs_inputs = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            input_ids = kwargs_inputs.pop("input_ids")
            loss = model(input_ids, **kwargs_inputs)[0]
            self.assertEqual(loss.shape, [loss_size])

            # 2) a single dict
            dict_inputs = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            loss = model(dict_inputs)[0]
            self.assertEqual(loss.shape, [loss_size])

            # 3) a positional tuple
            tuple_source = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
            # keys that were added by _prepare_for_class
            label_keys = tuple_source.keys() - inputs_dict.keys()
            arg_names = inspect.getfullargspec(model.call)[0]

            # map each tensor to its position in model.call's signature (position 0 is `self`)
            positions = {1: "input_ids"}
            for label_key in label_keys:
                positions[arg_names.index(label_key)] = label_key

            # unused positions stay None; the tuple stops at the last filled position
            positional_args = [None] * max(positions)
            for position, key in positions.items():
                positional_args[position - 1] = tuple_source[key]

            loss = model(tuple(positional_args))[0]
            self.assertEqual(loss.shape, [loss_size])

785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
    def _generate_random_bad_tokens(self, num_bad_tokens, model):
        """Draw ``num_bad_tokens`` random token ids that are not special tokens.

        Special tokens (the bos/pad/eos ids set on ``model.config``) cannot be bad tokens.
        Candidates come from ``ids_tensor`` over ``self.model_tester.vocab_size``. Duplicates
        among the returned ids are possible.
        NOTE(review): this loops forever if every vocabulary id is a special token.
        """
        config = model.config
        special_tokens = [
            token_id
            for token_id in (config.bos_token_id, config.pad_token_id, config.eos_token_id)
            if token_id is not None
        ]

        bad_tokens = []
        while len(bad_tokens) < num_bad_tokens:
            candidate = tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), 0).numpy()[0]
            if candidate in special_tokens:
                continue
            bad_tokens.append(candidate)
        return bad_tokens

803
    def _check_generated_ids(self, output_ids):
        """Assert that every generated token id is a valid vocabulary index.

        Args:
            output_ids: tensor of generated ids, expected to be 2-D (one row per returned
                sequence).

        Every row is checked, not only the first. Otherwise, with ``num_return_sequences > 1``,
        all sequences after the first were never validated.
        """
        for sequence in output_ids.numpy().tolist():
            for token_id in sequence:
                self.assertGreaterEqual(token_id, 0)
                self.assertLess(token_id, self.model_tester.vocab_size)

808
809
810
811
812
813
814
815
816
817
818
819
    def _check_match_tokens(self, generated_ids, bad_words_ids):
        # for all bad word tokens
        for bad_word_ids in bad_words_ids:
            # for all slices in batch
            for generated_ids_slice in generated_ids:
                # for all word idx
                for i in range(len(bad_word_ids), len(generated_ids_slice)):
                    # if tokens match
                    if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
                        return True
        return False

thomwolf's avatar
thomwolf committed
820

thomwolf's avatar
thomwolf committed
821
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
    """Create a random integer tensor of the given shape with values in ``[0, vocab_size)``.

    Args:
        shape: dimension sizes of the result; ``prod(shape)`` values are sampled.
        vocab_size: exclusive upper bound of the sampled ids.
        rng: object with a ``randint`` method (e.g. ``random.Random``). A fresh
            ``random.Random()`` is created when omitted.
        name: accepted for API compatibility; not used by this function.
        dtype: TensorFlow dtype of the result; defaults to ``tf.int32``.
    """
    generator = random.Random() if rng is None else rng

    num_values = 1
    for dim in shape:
        num_values *= dim

    # randint is inclusive on both ends, hence vocab_size - 1
    values = [generator.randint(0, vocab_size - 1) for _ in range(num_values)]
    return tf.constant(values, shape=shape, dtype=tf.int32 if dtype is None else dtype)
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914


@require_tf
class UtilsFunctionsTest(unittest.TestCase):
    """Tests for standalone TF generation utility functions."""

    # tests whether the top_k_top_p_filtering function behaves as expected
    def test_top_k_top_p_filtering(self):
        """Check that ``tf_top_k_top_p_filtering`` keeps exactly the expected logits.

        Each of the two rows holds 30 logits. With ``top_k=10``, ``top_p=0.6`` and
        ``min_tokens_to_keep=4``, only the five highest values of each row (cumulative
        probability <= 0.6) must stay finite; every other entry must be ``-inf``.
        """
        first_row = [
            8.2220991,  # 3rd highest value; idx. 0
            -0.5620044,
            5.23229752,
            4.0386393,
            -6.8798378,
            -0.54785802,
            -3.2012153,
            2.92777176,
            1.88171953,
            7.35341276,  # 5th highest value; idx. 9
            8.43207833,  # 2nd highest value; idx. 10
            -9.85711836,
            -5.96209236,
            -1.13039161,
            -7.1115294,
            -0.8369633,
            -5.3186408,
            7.06427407,
            0.81369344,
            -0.82023817,
            -5.9179796,
            0.58813443,
            -6.99778438,
            4.71551189,
            -0.18771637,
            7.44020759,  # 4th highest value; idx. 25
            9.38450987,  # 1st highest value; idx. 26
            2.12662941,
            -9.32562038,
            2.35652522,
        ]  # cumulative prob of 5 highest values <= 0.6
        second_row = [
            0.58425518,
            4.53139238,
            -5.57510464,
            -6.28030699,
            -7.19529503,
            -4.02122551,
            1.39337037,
            -6.06707057,
            1.59480517,
            -9.643119,
            0.03907799,
            0.67231762,
            -8.88206726,
            6.27115922,  # 4th highest value; idx. 13
            2.28520723,
            4.82767506,
            4.30421368,
            8.8275313,  # 2nd highest value; idx. 17
            5.44029958,  # 5th highest value; idx. 18
            -4.4735794,
            7.38579536,  # 3rd highest value; idx. 20
            -2.91051663,
            2.61946077,
            -2.5674762,
            -9.48959302,
            -4.02922645,
            -1.35416918,
            9.67702323,  # 1st highest value; idx. 27
            -5.89478553,
            1.85370467,
        ]  # cumulative prob of 5 highest values <= 0.6
        logits = tf.convert_to_tensor([first_row, second_row], dtype=tf.float32)

        # (row, column) coordinates of the logits expected to survive, as noted above
        expected_kept_idx = tf.convert_to_tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=tf.int32,
        )
        # values of the surviving logits, in the same order as the coordinates above
        expected_kept_values = tf.convert_to_tensor(
            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
            dtype=tf.float32,
        )

        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        neg_inf = -float("inf")
        kept_values = output[output != neg_inf]
        kept_mask = tf.not_equal(output, tf.constant(neg_inf, dtype=tf.float32))
        kept_idx = tf.cast(tf.where(kept_mask), dtype=tf.int32)

        tf.debugging.assert_near(kept_values, expected_kept_values, rtol=1e-12)
        tf.debugging.assert_equal(kept_idx, expected_kept_idx)