"vscode:/vscode.git/clone" did not exist on "ebfdb9ca62205279d5019ef1403877461b3b2da4"
modeling_tf_utils.py 50.5 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import os
import re
import warnings
from typing import Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .utils import logging


logger = logging.get_logger(__name__)

class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to return only the number of trainable parameters

        Returns:
            :obj:`int`: The number of parameters.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()
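# Usage note (illustration only): any model built on this mixin, e.g. a TFPreTrainedModel subclass
# loaded with `from_pretrained`, exposes
#
#     model.num_parameters()                     # total parameter count
#     model.num_parameters(only_trainable=True)  # trainable parameters only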


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a :obj:`config` dict to the Keras config dictionary in :obj:`get_config` (called by Keras at
       serialization time).
    2. Wrapping :obj:`__init__` to accept that :obj:`config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the TensorFlow version supports this), so that it does not
       need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layer` subclass):
            Typically a :obj:`TF.MainLayer` class in this project, in general must accept a :obj:`config` argument to
            its initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["config"] = self._config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
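# Illustrative sketch, not part of the library: for a hypothetical decorated layer
#
#     @keras_serializable
#     class MyMainLayer(tf.keras.layers.Layer):
#         config_class = PretrainedConfig
#
#         def __init__(self, config, **kwargs):
#             super().__init__(**kwargs)
#             self.config = config
#
# `get_config` stores `config` as a dict under the "config" key, and `wrapped_init` turns that dict
# back into a `PretrainedConfig` when Keras deserializes the layer (e.g. via `tf.keras.models.load_model`),
# so the class does not need to be passed in `custom_objects`.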


class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100 affect the loss
        active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        return loss_fn(labels, reduced_logits)
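# Worked illustration (not executed): with labels [[5, -100, 7]] and logits of shape (1, 3, vocab),
# `active_loss` keeps positions 0 and 2 only, so the -100 position never contributes to the loss:
#
#     loss = TFCausalLanguageModelingLoss().compute_loss(
#         labels=tf.constant([[5, -100, 7]]),
#         logits=tf.random.uniform((1, 3, 30522)),
#     )  # -> shape (2,), one cross-entropy value per unmasked token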


class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0


class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account as loss
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            active_loss = tf.reshape(labels, (-1,)) != -1
        else:
            active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

        return loss_fn(labels, reduced_logits)


class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def compute_loss(self, labels, logits):
        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)
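# Illustration only: a single logit column is treated as regression (mean squared error), anything
# wider as classification (sparse categorical cross-entropy):
#
#     mixin = TFSequenceClassificationLoss()
#     mixin.compute_loss(tf.constant([[0.5]]), tf.constant([[0.7]]))     # -> MSE branch
#     mixin.compute_loss(tf.constant([1]), tf.constant([[0.1, 2.0]]))    # -> cross-entropy branch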


class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
    """Loss function suitable for multiple choice tasks."""


class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    .. note::

         Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """


class TFNextSentencePredictionLoss:
    """
    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.

    .. note::
         Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account as loss
        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)

        return loss_fn(next_sentence_label, next_sentence_reduced_logits)


def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Detect missing and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
    missing_layers = []
    unexpected_layers = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_layers = list(model_layer_names - saved_layer_names)
        unexpected_layers = list(saved_layer_names - model_layer_names)

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                saved_weight_names_set = set(
                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
                )
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                symbolic_weights_names = set(
                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
                )
                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers
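# Hedged usage sketch: given an instantiated model and a resolved H5 checkpoint path (the path below
# is hypothetical), the two lists name weights present on only one side:
#
#     missing, unexpected = detect_tf_missing_unexpected_layers(model, "/path/to/tf_model.h5")
#     # `missing`: defined by the model but absent from the file; `unexpected`: the reverse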


def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from a H5 file.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.
    """
    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        weight_value_tuples = []

        for layer in model.layers:
            if layer.name in saved_layer_names:
                g = f[layer.name]
                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weight_names_values = {}

                for weight_name in saved_weight_names:
                    name = "/".join(weight_name.split("/")[1:])
                    saved_weight_names_values[name] = np.asarray(g[weight_name])

                for symbolic_weight in symbolic_weights:
                    splited_layers = symbolic_weight.name.split("/")[1:]
                    symbolic_weight_name = "/".join(splited_layers)

                    if symbolic_weight_name in saved_weight_names_values:
                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]

                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except AssertionError as e:
                                e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
                                raise e
                        else:
                            array = saved_weight_value

                        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re patterns of tensor names to
          ignore from the model when loading the model weights (and avoid unnecessary warnings).
        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re patterns of tensor names to
          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
    """
    config_class = None
    base_model_prefix = ""
    authorized_missing_keys = None
    authorized_unexpected_keys = None

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: A Keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings.

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: A Keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
                anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
451
452
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        return model_embeds
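    # Typical (hedged) usage after adding tokens to a tokenizer: `model.resize_token_embeddings(len(tokenizer))`
    # grows or shrinks the input embedding matrix and keeps `config.vocab_size` in sync.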

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)
        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()
        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list of
                heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads
                0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, TFBertModel
            >>> # Download model and configuration from S3 and cache.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = TFBertModel.from_pretrained('./test/saved_model/')
            >>> # Update configuration during loading.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> assert model.config.output_attentions == True
            >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            >>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            load_tf_weights(model, resolved_archive_file)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

        if cls.authorized_missing_keys is not None:
            for pat in cls.authorized_missing_keys:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls.authorized_unexpected_keys is not None:
            for pat in cls.authorized_unexpected_keys:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}

            return model, loading_info

        return model


class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])
        return x
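# Shape sketch (illustration only): TFConv1D acts as a transposed linear layer, so a layer built as
# `TFConv1D(nf=3072, nx=768)` maps an input of shape (batch, seq_len, 768) to (batch, seq_len, 3072):
#
#     y = TFConv1D(nf=3072, nx=768)(tf.zeros((2, 5, 768)))  # -> shape (2, 5, 3072)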


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.

    The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
    modeling.

    Args:
        vocab_size (:obj:`int`):
            The size of the vocabulary, e.g., the number of unique tokens.
        hidden_size (:obj:`int`):
            The size of the embedding vectors.
        initializer_range (:obj:`float`, `optional`):
            The standard deviation to use when initializing the weights. If no value is provided, it will default to
            :math:`1/\sqrt{hidden\_size}`.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """
        Build shared token embedding layer. Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """
        Get token embeddings of inputs or decode final hidden state.

        Args:
            inputs (:obj:`tf.Tensor`):
                In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.

                In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
            mode (:obj:`str`, defaults to :obj:`"embedding"`):
               A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
               should be used as an embedding layer, the second one that the layer should be used as a linear decoder.

        Returns:
            :obj:`tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape
            :obj:`[batch_size, length, embedding_size]`.

            In linear mode, the output is a float32 with shape :obj:`[batch_size, length, vocab_size]`.

        Raises:
            ValueError: if :obj:`mode` is not valid.

        Shared weights logic is adapted from `here
        <https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [..., hidden_size]

        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
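    # Illustration only: the same weight matrix serves both modes. For a layer built with
    # vocab_size=100 and hidden_size=16:
    #
    #     emb = TFSharedEmbeddings(100, 16)
    #     hidden = emb(tf.constant([[1, 2, 3]]))   # "embedding" mode -> (1, 3, 16)
    #     logits = emb(hidden, mode="linear")      # decoder mode     -> (1, 3, 100)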


class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`) -- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

Sylvain Gugger's avatar
Sylvain Gugger committed
953
    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
Julien Chaumond's avatar
Julien Chaumond committed
954
        super().__init__(**kwargs)
thomwolf's avatar
thomwolf committed
955

956
957
        self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
        if self.summary_type == "attn":
thomwolf's avatar
thomwolf committed
958
959
960
961
962
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

963
        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1) where XX are optional leading dims of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output
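

# --- Hedged usage sketch (illustration only, not part of the library API) ---
# The helper below shows how TFSequenceSummary is typically driven by a model
# config. `_demo_sequence_summary` is a hypothetical name, and the attributes
# read from `config` (summary_type, summary_use_proj, num_labels, hidden_size,
# ...) are assumed to be provided by a real model config as documented in the
# class docstring above; nothing here is executed at import time.
def _demo_sequence_summary(config: PretrainedConfig) -> tf.Tensor:
    # Build the summary layer from the config options described above.
    summary = TFSequenceSummary(config, initializer_range=0.02, name="demo_sequence_summary")
    # A random batch of hidden states of shape (batch, seq_len, hidden_size).
    hidden_states = tf.random.normal((2, 8, config.hidden_size))
    # With summary_type="first" this pools hidden_states[:, 0]; the optional
    # dropout / projection / tanh activation are then applied as configured.
    return summary(hidden_states)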


def shape_list(x: tf.Tensor) -> List[int]:
    """
    Deal with dynamic shape in tensorflow cleanly.

    Args:
        x (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list.
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
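

# --- Hedged usage sketch (illustration only) ---
# Inside a tf.function, x.shape may contain None for axes that are only known at
# run time, while tf.shape(x) is fully dynamic. shape_list mixes the two so that
# statically known sizes stay plain Python ints and only unknown axes become
# scalar tensors. `_demo_shape_list` is a hypothetical helper, never called here.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 768], dtype=tf.float32)])
def _demo_shape_list(x: tf.Tensor) -> tf.Tensor:
    # batch and seq_len are scalar tensors (dynamic); hidden is the Python int 768.
    batch, seq_len, hidden = shape_list(x)
    # The static hidden size can be used directly when building a new shape.
    return tf.reshape(x, [batch * seq_len, hidden])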


def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Creates a :obj:`tf.initializers.TruncatedNormal` with the given range.

    Args:
        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation of the truncated normal initializer.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
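

# --- Hedged usage sketch (illustration only) ---
# Typical use: hand the initializer to a freshly added head so its weights follow
# the same truncated normal scheme as the pretrained body. `_demo_classifier_head`
# and its arguments are hypothetical names used only for illustration.
def _demo_classifier_head(num_labels: int, initializer_range: float = 0.02) -> tf.keras.layers.Dense:
    return tf.keras.layers.Dense(
        num_labels, kernel_initializer=get_initializer(initializer_range), name="demo_classifier"
    )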


def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
    """
    Function arguments can be passed either as boolean tensors or as Python booleans. To cope with Keras serialization,
    we need to cast boolean arguments (like :obj:`output_attentions` for instance) to a proper Python boolean when they
    are passed as tensors.

    Args:
        bool_variable (:obj:`Union[tf.Tensor, bool]`):
            The variable to convert to a boolean.
        default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
            The default value to use in case the tensor has no numpy attribute.

    Returns:
        :obj:`bool`: The converted value.
    """
    # if bool variable is tensor and has numpy value
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        elif default_tensor_to_true:
            return True

    # else variable is bool
    return bool_variable
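

# --- Hedged usage sketch (illustration only) ---
# In eager mode a boolean tf.constant exposes .numpy() and is unwrapped to a
# plain Python bool; plain bools pass through unchanged. In graph mode, where no
# numpy value is available, default_tensor_to_true picks the fallback. The demo
# helper below is hypothetical and never called at import time.
def _demo_cast_bool_to_primitive() -> bool:
    tensor_flag = tf.constant(True)  # eager mode: unwrapped via .numpy() -> True
    python_flag = False              # passed through unchanged -> False
    return cast_bool_to_primitive(tensor_flag) and not cast_bool_to_primitive(python_flag)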


class TFWrappedEmbeddings:
    """
    This class wraps the TFSharedEmbeddingTokens layer in a plain Python (non-Keras-layer) class to avoid problems
    with weight restoring. It also makes sure that the layer is called from the correct scope, so that the correct
    weights are saved and restored.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)
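

# --- Hedged usage sketch (illustration only) ---
# The wrapper is normally built once around a model's shared token embedding
# layer. `shared_embeddings` is assumed to be a TFSharedEmbeddings-style layer
# supporting the "embedding" and "linear" modes; the scope name used below is
# purely illustrative, and `_demo_wrapped_embeddings` is a hypothetical helper.
def _demo_wrapped_embeddings(shared_embeddings, input_ids: tf.Tensor) -> tf.Tensor:
    # Record the absolute scope the shared weights live under, so that calling
    # the wrapper from another sub-model still resolves to the same variables.
    with tf.compat.v1.variable_scope("demo_model.shared") as shared_abs_scope_name:
        pass
    wrapped = TFWrappedEmbeddings(shared_embeddings, abs_scope_name=shared_abs_scope_name)
    # mode="embedding" maps token ids to embeddings; mode="linear" would map
    # hidden states back to vocabulary logits.
    return wrapped(input_ids, mode="embedding")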