# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import os
import re
import warnings
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .utils import logging


logger = logging.get_logger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to return only the number of trainable parameters.

        Returns:
            :obj:`int`: The number of parameters.
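
        Example (illustrative; assumes a concrete subclass such as :class:`~transformers.TFBertModel`)::

            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> model.num_parameters(only_trainable=True)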
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a :obj:`config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at
       serialization time).
    2. Wrapping :obj:`__init__` to accept that :obj:`config` dict (passed by Keras at deserialization time) and
       converting it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the TensorFlow version supports this), so that it does not
       need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layer` subclass):
            Typically a :obj:`TF*MainLayer` class in this project, which in general must accept a :obj:`config`
            argument to its initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
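
    Example (a minimal sketch; :obj:`MyConfig` stands for a hypothetical :class:`~transformers.PretrainedConfig`
    subclass)::

        @keras_serializable
        class TFMyMainLayer(tf.keras.layers.Layer):
            config_class = MyConfig

            def __init__(self, config, **kwargs):
                super().__init__(**kwargs)
                self.config = config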
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["config"] = self._config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls


class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
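
    For example (illustrative), assuming :obj:`model` is an instance of a model class using this mixin::

        labels = tf.constant([[2, 3, -100]])       # -100 masks the last position
        logits = tf.random.uniform((1, 3, 5))      # [batch, seq_len, vocab_size]
        loss = model.compute_loss(labels, logits)  # only the first two positions contribute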

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100 affect the loss
        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        return loss_fn(labels, reduced_logits)


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
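
    :obj:`labels` is expected to be a dict with :obj:`"start_position"` and :obj:`"end_position"` tensors, and
    :obj:`logits` a (start, end) pair; the two losses are averaged. For example (illustrative)::

        labels = {"start_position": tf.constant([3]), "end_position": tf.constant([5])}
        loss = model.compute_loss(labels, (start_logits, end_logits))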
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100 are taken into account in the loss
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            active_loss = tf.reshape(labels, (-1,)) != -1
        else:
            active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

        return loss_fn(labels, reduced_logits)


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
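
    Mean squared error is used when the logits have a single column (regression); sparse categorical cross-entropy is
    used otherwise. For example (illustrative)::

        labels = tf.constant([1, 0])               # class indices
        logits = tf.random.uniform((2, 3))         # 3 classes -> cross-entropy
        loss = model.compute_loss(labels, logits)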
    """

    def compute_loss(self, labels, logits):
        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)


class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
    """Loss function suitable for multiple choice tasks."""


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """


def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Detect missing and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
    missing_layers = []
    unexpected_layers = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_layers = list(model_layer_names - saved_layer_names)
        unexpected_layers = list(saved_layer_names - model_layer_names)

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                saved_weight_names_set = set(
                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
                )
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                symbolic_weights_names = set(
                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
                )
                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers


def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from an H5 file.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        weight_value_tuples = []

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weight_names_values = {}

                for weight_name in saved_weight_names:
                    name = "/".join(weight_name.split("/")[1:])
                    saved_weight_names_values[name] = np.asarray(g[weight_name])

                for symbolic_weight in symbolic_weights:
                    split_layers = symbolic_weight.name.split("/")[1:]
                    symbolic_weight_name = "/".join(split_layers)

                    if symbolic_weight_name in saved_weight_names_values:
                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]

                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
                                raise e
                        else:
                            array = saved_weight_value

                        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of regular expression patterns of
          tensor names to ignore from the model when loading the model weights (and avoid unnecessary warnings).
        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of regular expression patterns of
          tensor names to ignore from the weights when loading the model weights (and avoid unnecessary warnings).
    """
    config_class = None
    base_model_prefix = ""
    authorized_missing_keys = None
    authorized_unexpected_keys = None

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: A Keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings.

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: A Keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying the weight embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
                anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
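
        Example (illustrative; assumes a concrete subclass such as :class:`~transformers.TFBertModel`)::

            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # grow the vocabulary by 2 tokens; the new rows are randomly initialized
            >>> embeddings = model.resize_token_embeddings(model.config.vocab_size + 2)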
        """
        if new_num_tokens is None:
            # Nothing to resize: return a pointer to the current input token embeddings.
            base_model = getattr(self, self.base_model_prefix, self)
            return self._get_word_embeddings(base_model.get_input_embeddings())

        return self._resize_token_embeddings(new_num_tokens)

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)
        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`.
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list of
                heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads
                0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
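
        Example (illustrative round trip)::

            >>> model.save_pretrained('./my_model_directory/')
            >>> model = TFBertModel.from_pretrained('./my_model_directory/')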
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            mirror (:obj:`str`, `optional`, defaults to :obj:`None`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, TFBertModel
            >>> # Download model and configuration from S3 and cache.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = TFBertModel.from_pretrained('./test/saved_model/')
            >>> # Update configuration during loading.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> assert model.config.output_attentions == True
            >>> # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            >>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error: no file named {} found in directory {} (or `from_pt` was not set to True)".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            load_tf_weights(model, resolved_archive_file)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

        if cls.authorized_missing_keys is not None:
            for pat in cls.authorized_missing_keys:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls.authorized_unexpected_keys is not None:
            for pat in cls.authorized_unexpected_keys:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}

            return model, loading_info

        return model


class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
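
    Example (illustrative; works like a dense layer with transposed weights)::

        layer = TFConv1D(nf=768, nx=256)
        x = tf.random.uniform((2, 10, 256))  # [batch, seq_len, nx]
        y = layer(x)                         # [batch, seq_len, nf]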
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (:obj:`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (:obj:`int`):
            The size of the embedding vectors.
        initializer_range (:obj:`float`, `optional`):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            :math:`1/\sqrt{hidden\_size}`.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """
        Build the shared token embedding layer. Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (:obj:`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.

                In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
            mode (:obj:`str`, defaults to :obj:`"embedding"`):
                A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
                should be used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            :obj:`tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape
            :obj:`[batch_size, length, embedding_size]`.

            In linear mode, the output is a float32 tensor with shape :obj:`[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if :obj:`mode` is not valid.

        Shared weights logic is adapted from `here
        <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
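
        Example (illustrative)::

            >>> emb = TFSharedEmbeddings(vocab_size=100, hidden_size=16)
            >>> hidden = emb(tf.constant([[1, 2, 3]]))  # embedding mode: [1, 3, 16]
            >>> logits = emb(hidden, mode="linear")     # linear mode: [1, 3, 100]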
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence's hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
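
    Example (a minimal sketch; assumes :obj:`config` carries the summary attributes described above)::

        summary = TFSequenceSummary(config, initializer_range=0.02)
        pooled = summary(hidden_states)  # e.g. [batch, hidden] from [batch, seq_len, hidden]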
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
Julien Chaumond's avatar
Julien Chaumond committed
933
        super().__init__(**kwargs)
thomwolf's avatar
thomwolf committed
934

935
936
        self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
        if self.summary_type == "attn":
thomwolf's avatar
thomwolf committed
937
938
939
940
941
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

942
        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
943
        if self.has_summary:
944
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
thomwolf's avatar
thomwolf committed
945
946
947
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
948
949
950
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )
thomwolf's avatar
thomwolf committed
951

952
        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
953
        if self.has_activation:
954
            self.activation = tf.keras.activations.tanh
thomwolf's avatar
thomwolf committed
955

956
        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
957
        if self.has_first_dropout:
thomwolf's avatar
thomwolf committed
958
959
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # a tensor of shape [batch] or [batch, num choices], filled with sequence length - 1 (the last token index)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # shape of cls_index: (bsz, XX, 1), i.e. one token index per sequence, where XX are the optional leading dims of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output


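# Illustrative usage sketch (hypothetical, not part of the module): wiring the
# sequence-summary layer above (``TFSequenceSummary``) from a config. The config
# attributes mirror the **summary_*** options documented in the class docstring;
# all values below are made up for the demo.
def _example_sequence_summary():
    config = PretrainedConfig(
        summary_type="cls_index",
        summary_use_proj=True,
        summary_proj_to_labels=True,
        num_labels=2,
        hidden_size=8,
        summary_activation="tanh",
        summary_first_dropout=0.1,
        summary_last_dropout=0.0,
    )
    summary = TFSequenceSummary(config, initializer_range=0.02, name="demo_summary")
    hidden_states = tf.random.uniform((4, 10, 8))  # [batch, seq length, hidden size]
    cls_index = tf.constant([9, 3, 5, 0])  # classification token position per example
    return summary(hidden_states, cls_index=cls_index)  # shape [4, 2]
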
def shape_list(x: tf.Tensor) -> List[int]:
    """
    Deal with dynamic shape in TensorFlow cleanly.

    Args:
        x (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list, with dynamic scalar tensors in place of dimensions
        that are statically unknown.
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

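
# Illustrative sketch (hypothetical demo function): inside a ``tf.function`` traced
# with an unknown batch dimension, ``shape_list`` returns the static Python ints for
# the known axes and a scalar tensor for the unknown one, so the result can be fed
# directly to ops like ``tf.reshape``.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 8, 16], dtype=tf.float32)])
def _example_shape_list(x):
    batch, seq_len, hidden = shape_list(x)  # batch is a scalar tf.Tensor, seq_len == 8, hidden == 16
    return tf.reshape(x, [batch, seq_len * hidden])
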

def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (:obj:`float`, defaults to 0.02): Standard deviation of the truncated normal initializer.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)

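
# Usage sketch: layers in this module create their kernels this way, e.g. the
# ``summary`` projection above. The Dense layer below is purely illustrative.
_example_dense = tf.keras.layers.Dense(10, kernel_initializer=get_initializer(0.02))
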

def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true: bool = False) -> bool:
    """
    Function arguments can be passed either as boolean tensors or as plain Python booleans. To cope with Keras
    serialization, boolean arguments (like :obj:`output_attentions` for instance) have to be cast back to plain
    Python booleans when they arrive as tensors.

    Args:
        bool_variable (:obj:`Union[tf.Tensor, bool]`):
            The variable to convert to a boolean.
        default_tensor_to_true (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to return :obj:`True` when the tensor has no numpy attribute (i.e. is symbolic). If
            :obj:`False`, a symbolic tensor is returned unchanged.

    Returns:
        :obj:`bool`: The converted value.
    """
    # if the bool variable is a tensor, try to read its concrete numpy value
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        elif default_tensor_to_true:
            return True

    # else: the variable is a plain bool, or a symbolic tensor without a numpy value (returned unchanged)
    return bool_variable
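

# Illustrative sketch (hypothetical demo function): in eager mode a boolean tensor
# carries a numpy value and is converted to a Python bool, while plain booleans are
# passed straight through.
def _example_cast_bool():
    from_tensor = cast_bool_to_primitive(tf.constant(True))  # -> True (a Python bool)
    from_bool = cast_bool_to_primitive(False)  # non-tensors are returned unchanged -> False
    return from_tensor, from_bool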


class TFWrappedEmbeddings:
    """
    This class wraps the :obj:`TFSharedEmbeddings` layer in a plain Python (non-Keras-layer) class, to avoid problems
    with weight restoring. It also makes sure the layer is called from the correct scope, so that the correct weights
    are saved and restored.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an absolute scope name is given for the embedding variable, call the layer from that scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an absolute scope name is given for the embedding variable, call the layer from that scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)
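

# Hypothetical usage sketch: wrapping a shared embedding layer so the same weights
# serve both as input embeddings (``mode="embedding"``) and as the output projection
# (``mode="linear"``). ``TFSharedEmbeddings`` is defined earlier in this module; the
# sizes below are illustrative.
def _example_wrapped_embeddings():
    shared = TFSharedEmbeddings(vocab_size=100, hidden_size=32, name="shared")
    wrapped = TFWrappedEmbeddings(shared, abs_scope_name=None)
    token_ids = tf.constant([[1, 2, 3]])
    embeddings = wrapped(token_ids, mode="embedding")  # shape [1, 3, 32]
    logits = wrapped(embeddings, mode="linear")  # shape [1, 3, 100]
    return logits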