# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import os
import re
import warnings
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .utils import logging


logger = logging.get_logger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to return only the number of trainable parameters

        Returns:
            :obj:`int`: The number of parameters.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a :obj:`config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at
       serialization time).
    2. Wrapping :obj:`__init__` to accept that :obj:`config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
       need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layer` subclass):
            Typically a :obj:`TF.MainLayer` class in this project, in general must accept a :obj:`config` argument to
            its initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
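
    Example (a minimal sketch; :obj:`MyConfig` and :obj:`MyMainLayer` are hypothetical names, not classes shipped
    with this library)::

        @keras_serializable
        class MyMainLayer(tf.keras.layers.Layer):
            config_class = MyConfig  # a PretrainedConfig subclass

            def __init__(self, config, **kwargs):
                super().__init__(**kwargs)
                self.config = config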
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either a `PretrainedConfig` instance or a dict as `config`")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["config"] = self._config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls


class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
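
    Example (a minimal sketch; assumes :obj:`model` mixes in this class and has a 10-token vocabulary)::

        >>> labels = tf.constant([[-100, 4, 7]])       # the first position is masked out
        >>> logits = tf.random.uniform((1, 3, 10))
        >>> loss = model.compute_loss(labels, logits)  # per-token losses for the two unmasked positions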

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100 affect the loss
        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        return loss_fn(labels, reduced_logits)


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
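
    Example (a minimal sketch; :obj:`start_logits` and :obj:`end_logits` are illustrative tensors of shape
    :obj:`[batch_size, seq_length]`)::

        >>> labels = {"start_position": tf.constant([1]), "end_position": tf.constant([3])}
        >>> loss = model.compute_loss(labels, (start_logits, end_logits))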
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            active_loss = tf.reshape(labels, (-1,)) != -1
        else:
            active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

        return loss_fn(labels, reduced_logits)


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def compute_loss(self, labels, logits):
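        # A single output dimension (or rank-1 logits) indicates a regression head, so use MSE
        # rather than sparse categorical cross-entropy.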
        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)


class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
    """Loss function suitable for multiple choice tasks."""


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """


def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Detect missing and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
    missing_layers = []
    unexpected_layers = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_layers = list(model_layer_names - saved_layer_names)
        unexpected_layers = list(saved_layer_names - model_layer_names)

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
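                # Weight names are fully scoped (e.g. "model/layer/.../kernel:0"); drop the two outer
                # name scopes so checkpoint names and model weight names can be compared directly.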
                saved_weight_names_set = set(
                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
                )
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                symbolic_weights_names = set(
                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
                )
                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers


def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from a H5 file.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        weight_value_tuples = []

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weight_names_values = {}

                for weight_name in saved_weight_names:
                    name = "/".join(weight_name.split("/")[1:])
                    saved_weight_names_values[name] = np.asarray(g[weight_name])

                for symbolic_weight in symbolic_weights:
                    split_layer_names = symbolic_weight.name.split("/")[1:]
                    symbolic_weight_name = "/".join(split_layer_names)

                    if symbolic_weight_name in saved_weight_names_values:
                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]

                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
                                raise e
                        else:
                            array = saved_weight_value

                        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention layers.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of :obj:`re` patterns of tensor names
          to ignore from the model when loading the model weights (and avoid unnecessary warnings).
        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of :obj:`re` patterns of tensor
          names to ignore from the weights when loading the model weights (and avoid unnecessary warnings).
    """
    config_class = None
    base_model_prefix = ""
    authorized_missing_keys = None
    authorized_unexpected_keys = None

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The embedding layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings.

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
                anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
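
        Example (a minimal sketch; assumes a tokenizer to which new tokens have been added)::

            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> new_embeddings = model.resize_token_embeddings(len(tokenizer))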
        """
        if new_num_tokens is None:
            # Nothing to resize: return the current input token embeddings (and leave config.vocab_size untouched).
            base_model = getattr(self, self.base_model_prefix, self)
            return self._get_word_embeddings(base_model.get_input_embeddings())

        return self._resize_token_embeddings(new_num_tokens)

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)
        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list of
                heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads
                0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
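
        Example (a minimal sketch; :obj:`TFBertModel` is used for illustration)::

            >>> model.save_pretrained('./my_model_directory/')
            >>> model = TFBertModel.from_pretrained('./my_model_directory/')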
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
            mirror (:obj:`str`, `optional`, defaults to :obj:`None`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, TFBertModel
            >>> # Download model and configuration from S3 and cache.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = TFBertModel.from_pretrained('./test/saved_model/')
            >>> # Update configuration during loading.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> assert model.config.output_attentions == True
            >>> # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            >>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error: no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    use_cdn=use_cdn,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            load_tf_weights(model, resolved_archive_file)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

        if cls.authorized_missing_keys is not None:
            for pat in cls.authorized_missing_keys:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls.authorized_unexpected_keys is not None:
            for pat in cls.authorized_unexpected_keys:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}

            return model, loading_info

        return model


class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
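
    Example (a minimal sketch; the shapes are illustrative)::

        >>> layer = TFConv1D(nf=768, nx=256)
        >>> x = tf.random.uniform((2, 10, 256))  # [batch_size, seq_length, nx]
        >>> y = layer(x)                         # [batch_size, seq_length, nf]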
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (:obj:`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (:obj:`int`):
            The size of the embedding vectors.
        initializer_range (:obj:`float`, `optional`):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            :math:`1/\sqrt{hidden\_size}`.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
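
    Example (a minimal sketch; :obj:`input_ids` and :obj:`hidden_states` are assumed model inputs/outputs)::

        >>> embeddings = TFSharedEmbeddings(vocab_size=30522, hidden_size=768)
        >>> token_embeddings = embeddings(input_ids)           # [batch_size, length, hidden_size]
        >>> logits = embeddings(hidden_states, mode="linear")  # [batch_size, length, vocab_size]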
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """
        Build shared token embedding layer. Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (:obj:`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.

                In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
            mode (:obj:`str`, defaults to :obj:`"embedding"`):
                A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
                should be used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            :obj:`tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape
            :obj:`[batch_size, length, embedding_size]`.

            In linear mode, the output is a float32 with shape :obj:`[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if :obj:`mode` is not valid.

        Shared weights logic is adapted from `here
        <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
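
    Example (a minimal sketch; :obj:`config` and :obj:`hidden_states` are assumed to come from an actual model)::

        >>> summary = TFSequenceSummary(config, initializer_range=0.02)
        >>> summary_vector = summary(hidden_states)  # one vector per sequence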
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor of shape [batch] or [batch, num choices], filled with the last token index (sequence length - 1)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1) where XX are optional leading dims of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output


def shape_list(x: tf.Tensor) -> List[int]:
    """
    Deal with dynamic shapes in TensorFlow cleanly.

    Args:
        x (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list.
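
    Example (a minimal sketch; in eager mode every dimension is static, while inside a :obj:`tf.function`
    an unknown dimension is returned as a scalar tensor rather than :obj:`None`)::

        x = tf.zeros((2, 5, 8))
        shape_list(x)  # [2, 5, 8]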
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (:obj:`float`, defaults to 0.02): Standard deviation of the truncated normal initializer.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
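
    Example (a minimal sketch of the typical usage in this module)::

        dense = tf.keras.layers.Dense(768, kernel_initializer=get_initializer(0.02), name="dense")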
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
    """
    Function arguments can be passed either as boolean tensors or as plain Python booleans. To cope with Keras
    serialization, we need to cast bool arguments (like :obj:`output_attentions` for instance) to a proper Python
    boolean when they come in as tensors.

    Args:
        bool_variable (:obj:`Union[tf.Tensor, bool]`):
            The variable to convert to a boolean.
        default_tensor_to_true (:obj:`bool`, `optional`, defaults to :obj:`False`):
            The default value to use in case the tensor has no numpy attribute.

    Returns:
        :obj:`bool`: The converted value.
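
    Example (illustrative; assumes eager execution so the tensor carries a numpy value)::

        cast_bool_to_primitive(tf.constant(True))  # True
        cast_bool_to_primitive(False)  # False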
    """
    # if bool variable is tensor and has numpy value
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        elif default_tensor_to_true:
            return True

    # else the variable is a plain bool (or a symbolic tensor without a numpy value) and is returned unchanged
    return bool_variable


class TFWrappedEmbeddings:
    """
    This class wraps the TFSharedEmbeddingTokens layer in a plain Python, non-Keras-layer class to avoid problems
    with weight restoring. It also makes sure that the layer is called from the correct scope so that the correct
    weights are saved and restored.
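
    Example (a minimal sketch, reusing the :obj:`TFSharedEmbeddings` layer defined in this module; the
    vocabulary and hidden sizes are arbitrary)::

        shared = TFSharedEmbeddings(vocab_size=100, hidden_size=32, name="shared")
        wrapped = TFWrappedEmbeddings(shared)
        token_ids = tf.constant([[4, 7, 9]])
        embeddings = wrapped(token_ids, mode="embedding")  # shape (1, 3, 32)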
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)