# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import os
import warnings
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
from .utils import logging


logger = logging.get_logger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to return only the number of trainable parameters

        Returns:
            :obj:`int`: The number of parameters.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()
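    # Usage sketch (illustrative only; ``model`` stands for any instantiated
    # :class:`~transformers.TFPreTrainedModel` subclass, e.g. a TFBertModel):
    #
    #     total_params = model.num_parameters()
    #     trainable_params = model.num_parameters(only_trainable=True)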


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a :obj:`transformers_config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at
       serialization time).
    2. Wrapping :obj:`__init__` to accept that :obj:`transformers_config` dict (passed by Keras at deserialization
       time) and convert it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
       not need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layers subclass`):
            Typically a :obj:`TF.MainLayer` class in this project, in general must accept a :obj:`config` argument to
            its initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        transformers_config = kwargs.pop("transformers_config", None)
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
        if config is not None and transformers_config is not None:
            raise ValueError("Must pass either `config` or `transformers_config`, not both")
        elif config is not None:
            # normal layer construction, call with unchanged args (config is already in there)
            initializer(self, *args, **kwargs)
        elif transformers_config is not None:
            # Keras deserialization, convert dict to config
            config = config_class.from_dict(transformers_config)
            initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
        self._transformers_config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["transformers_config"] = self._transformers_config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
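# Usage sketch for ``@keras_serializable`` (illustrative only; ``TFMyMainLayer`` and
# ``MyConfig`` are hypothetical names, not part of this module):
#
#     @keras_serializable
#     class TFMyMainLayer(tf.keras.layers.Layer):
#         config_class = MyConfig  # required by the decorator
#
#         def __init__(self, config, **kwargs):
#             super().__init__(**kwargs)
#             self.hidden_size = config.hidden_size
#
# A Keras model built from such a layer can then be restored with
# ``tf.keras.models.load_model`` without passing ``custom_objects``, provided the
# installed TensorFlow supports custom-object registration, because the decorator
# round-trips the config through the ``transformers_config`` entry of ``get_config``.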


class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account when computing the loss
        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        return loss_fn(labels, reduced_logits)
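    # Masking sketch (illustrative only, eager mode assumed; ``self`` is any model
    # class that mixes in this loss):
    #
    #     labels = tf.constant([[5, -100]])              # (batch_size, seq_len)
    #     logits = tf.random.uniform((1, 2, 30))         # (batch_size, seq_len, vocab_size)
    #     per_token_loss = self.compute_loss(labels, logits)
    #     # -> shape (1,): only the position labelled 5 contributes, -100 is masked out.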


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account when computing the loss
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            active_loss = tf.reshape(labels, (-1,)) != -1
        else:
            active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

        return loss_fn(labels, reduced_logits)


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def compute_loss(self, labels, logits):
        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)


class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
    """Loss function suitable for multiple choice tasks."""


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):
        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The embedding layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings.

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
                anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        return model_embeds
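    # Usage sketch (illustrative only; ``model`` and ``tokenizer`` are hypothetical
    # instances, e.g. a TFBertModel and its matching tokenizer):
    #
    #     tokenizer.add_tokens(["<new_token>"])
    #     model.resize_token_embeddings(new_num_tokens=len(tokenizer))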

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)
        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
                of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
                prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))
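    # Round-trip sketch (illustrative only; ``model`` is a hypothetical instance):
    #
    #     model.save_pretrained("./my_model_directory/")
    #     reloaded = model.__class__.from_pretrained("./my_model_directory/")
    #
    # The directory then holds the configuration (``config.json``) and the TF 2.0
    # weights file (``TF2_WEIGHTS_NAME``), which is what ``from_pretrained`` looks for.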

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
                request.
            output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
            mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
                you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
                refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            from transformers import BertConfig, TFBertModel
            # Download model and configuration from S3 and cache.
            model = TFBertModel.from_pretrained('bert-base-uncased')
            # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            model = TFBertModel.from_pretrained('./test/saved_model/')
            # Update configuration during loading.
            model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            assert model.config.output_attentions == True
            # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    use_cdn=use_cdn,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model


class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x
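# Usage sketch for ``TFConv1D`` (illustrative only): it behaves like a dense layer
# applied position-wise, projecting the last dimension from ``nx`` to ``nf``.
#
#     conv = TFConv1D(nf=2304, nx=768, name="c_attn")
#     x = tf.ones((2, 5, 768))   # (batch_size, seq_len, nx)
#     y = conv(x)                # -> (2, 5, 2304) == (batch_size, seq_len, nf)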


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing
    language modeling.

    Args:
        vocab_size (:obj:`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (:obj:`int`):
            The size of the embedding vectors.
        initializer_range (:obj:`float`, `optional`):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            :math:`1/\sqrt{hidden\_size}`.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared token embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (:obj:`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.

                In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
            mode (:obj:`str`, defaults to :obj:`"embedding"`):
               A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
               should be used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            :obj:`tf.Tensor`:
            In embedding mode, the output is a float32 embedding tensor, with shape
            :obj:`[batch_size, length, embedding_size]`.

            In linear mode, the output is a float32 tensor with shape :obj:`[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if :obj:`mode` is not valid.

        Shared weights logic is adapted from
        `here <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
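# Usage sketch for ``TFSharedEmbeddings`` (illustrative only): the same weight matrix
# serves as the input embedding ("embedding" mode) and as the output projection back
# to vocabulary logits ("linear" mode).
#
#     shared = TFSharedEmbeddings(vocab_size=100, hidden_size=16, name="shared")
#     ids = tf.constant([[1, 2, 3]])
#     hidden = shared(ids, mode="embedding")   # -> (1, 3, 16)
#     logits = shared(hidden, mode="linear")   # -> (1, 3, 100)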


class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the
            actual config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation**  (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output
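# Usage sketch for ``TFSequenceSummary`` (illustrative only; ``config`` is assumed to
# carry the ``summary_*`` attributes described in the class docstring):
#
#     summary = TFSequenceSummary(config, initializer_range=0.02, name="sequence_summary")
#     hidden_states = tf.ones((2, 7, config.hidden_size))  # (batch_size, seq_len, hidden_size)
#     pooled = summary(hidden_states)   # -> (batch_size, num_labels or hidden_size)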


def shape_list(x: tf.Tensor) -> List[int]:
    """
    Deal with dynamic shape in tensorflow cleanly.

    Args:
        x (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list.
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
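# Behaviour sketch for ``shape_list`` (illustrative only): in eager mode every entry
# is a Python int; inside a ``tf.function`` with unknown dimensions, the unknown
# entries come back as scalar tensors taken from ``tf.shape``.
#
#     x = tf.ones((2, 3))
#     shape_list(x)   # -> [2, 3]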


def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (`float`, defaults to 0.02): Standard deviation of the initializer range.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
    """
    Function arguments can be passed as boolean tensors or plain bools. To cope with Keras serialization, we need to
    cast boolean arguments (like :obj:`output_attentions` for instance) to a proper Python boolean if they are tensors.

    Args:
        bool_variable (:obj:`Union[tf.Tensor, bool]`):
            The variable to convert to a boolean.
        default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
            The default value to use in case the tensor has no numpy attribute.

    Returns:
        :obj:`bool`: The converted value.
    """
    # if bool variable is tensor and has numpy value
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        elif default_tensor_to_true:
            return True

    # else variable is bool
    return bool_variable
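# Behaviour sketch for ``cast_bool_to_primitive`` (illustrative only):
#
#     cast_bool_to_primitive(True)                # -> True (plain bools pass through)
#     cast_bool_to_primitive(tf.constant(False))  # -> False (eager tensors expose .numpy())
#
# A symbolic tensor without a numpy value returns ``True`` when
# ``default_tensor_to_true=True``; otherwise the tensor itself is returned unchanged.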