modeling_tf_utils.py 29.5 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""


import logging
import os

Aymeric Augustin's avatar
Aymeric Augustin committed
22
import h5py
Julien Chaumond's avatar
Julien Chaumond committed
23
import numpy as np
thomwolf's avatar
thomwolf committed
24
import tensorflow as tf
thomwolf's avatar
thomwolf committed
25
from tensorflow.python.keras.saving import hdf5_format
thomwolf's avatar
thomwolf committed
26
27

from .configuration_utils import PretrainedConfig
28
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
29
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
thomwolf's avatar
thomwolf committed
30

Aymeric Augustin's avatar
Aymeric Augustin committed
31

thomwolf's avatar
thomwolf committed
32
33
logger = logging.getLogger(__name__)

34

Julien Chaumond's avatar
Julien Chaumond committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class TFModelUtils:
    """Small convenience mixin for ``tf.keras.Model`` subclasses."""

    def num_parameters(self, only_trainable: bool = False) -> int:
        """Return the number of parameters in the model.

        Args:
            only_trainable: when True, count only trainable variables;
                otherwise defer to Keras' ``count_params()``.
        """
        if not only_trainable:
            return self.count_params()
        total = 0
        for variable in self.trainable_variables:
            # prod of the static shape gives the element count of each variable
            total += np.prod(variable.shape.as_list())
        return int(total)


class TFPreTrainedModel(tf.keras.Model, TFModelUtils):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    # Subclasses must override these three attributes (see class docstring).
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""
    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            dict mapping ``"input_ids"`` to a small constant ``tf.Tensor`` of
            token ids — enough to trace the model once and create its weights.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        """Initialize the model and store its configuration.

        Args:
            config: a :class:`~transformers.PretrainedConfig` instance (required).
            *inputs, **kwargs: forwarded unchanged to ``tf.keras.Model.__init__``.

        Raises:
            ValueError: if ``config`` is not a ``PretrainedConfig`` instance.
        """
        super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    def get_input_embeddings(self):
        """Return the model's input embeddings.

        Delegates to the base model when this instance is a wrapper
        (i.e. ``base_model_prefix`` names a sub-model attribute);
        otherwise the subclass must override this method.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is self:
            raise NotImplementedError
        return base_model.get_input_embeddings()

    def get_output_embeddings(self):
        """ Get model's output embeddings
            Return None if the model doesn't have output embeddings
        """
        return None  # Overwrite for models with output embeddings

thomwolf's avatar
thomwolf committed
108
    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # NOTE(review): the whole implementation below is commented-out PyTorch
        # code awaiting a TF port, so this method currently falls through and
        # returns None regardless of arguments — callers should not rely on it
        # until the TF implementation lands.
        # if new_num_tokens is None:
        #     return old_embeddings

        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings

        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)

        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)

        # # Copy word embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        # return new_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model

        Raises:
            NotImplementedError: always, in this base class — subclasses must
                provide the actual resizing logic.
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).

            Raises:
                NotImplementedError: always, in this base class — subclasses
                    that support head pruning must override.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.

            Arguments:
                save_directory: path to an existing directory where the config
                    and the TF 2.0 (h5) weight file will be written.
        """
        # Explicit check instead of a bare `assert` so the validation still
        # runs under `python -O` (asserts are stripped). AssertionError is
        # kept so callers catching the previous exception type still work.
        if not os.path.isdir(save_directory):
            raise AssertionError("Saving path should be a directory where the model and configuration can be saved")

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_pt=True, config=config)

        """
        # Loader-specific options are popped so remaining kwargs can flow to
        # either the config loader or the model __init__ (see docstring).
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            # return_unused_kwargs=True splits kwargs into config overrides
            # (consumed) and model_kwargs (forwarded to the model __init__).
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model: resolve the weights archive from (in priority order)
        # the shortcut map, a local directory, a local file / remote url,
        # a TF 1.x ".index" checkpoint, or the S3 bucket by identifier.
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(pretrained_model_name_or_path, postfix=TF2_WEIGHTS_NAME)
                if from_pt:
                    raise EnvironmentError(
                        "Loading a TF model from a PyTorch checkpoint is not supported when using a model identifier name."
                    )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                )
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                        )
                    )
                raise e
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            # No name/path given: instantiate with freshly initialized weights.
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading informations
        with h5py.File(resolved_archive_file, "r") as f:
            # Keras may nest weights under a "model_weights" group when the
            # file was written by model.save() rather than save_weights().
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        # NOTE(review): error_msgs is never populated in the TF path — the
        # RuntimeError branch below is currently unreachable; kept for parity
        # with the PyTorch loader's reporting contract.
        error_msgs = []

        if len(missing_keys) > 0:
            logger.info(
                "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model
thomwolf's avatar
WIP  
thomwolf committed
381

382

thomwolf's avatar
WIP  
thomwolf committed
383
class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ 1D-convolution layer as defined by Radford et al. for OpenAI GPT
            (and also used in GPT-2). Works like a Dense/Linear layer, except
            the kernel is stored transposed: shape [nx, nf] (input dim first).
        """
        super(TFConv1D, self).__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        # Kernel stored input-dim-first — this is the "transposed" layout.
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        # Flatten to 2-D, project, then restore [batch, seq, nf].
        batch, seq_len = shape_list(x)[:2]
        flat = tf.reshape(x, [-1, self.nx])
        projected = tf.matmul(flat, self.weight) + self.bias
        return tf.reshape(projected, [batch, seq_len, self.nf])
thomwolf's avatar
thomwolf committed
408
409


thomwolf's avatar
thomwolf committed
410
411
412
class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.

    A single weight matrix serves both as the input embedding lookup and as
    the output projection onto the vocabulary (weight tying).
    """

    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super(TFSharedEmbeddings, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Default to the 1/sqrt(hidden_size) scale when no range is supplied.
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared word embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super(TFSharedEmbeddings, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Run the layer in one of its two modes.

        Args:
            inputs: int tensor of token ids ("embedding" mode) or float tensor
                of hidden states with trailing dim hidden_size ("linear" mode).
            mode: "embedding" to look up token vectors, "linear" to project
                hidden states onto vocabulary logits.
        Returns:
            "embedding": float tensor [..., hidden_size];
            "linear": float tensor [..., vocab_size].
        Raises:
            ValueError: if mode is not valid.
        """
        if mode == "embedding":
            return self._embedding(inputs)
        if mode == "linear":
            return self._linear(inputs)
        raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Look up embedding rows for the given token ids."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Project hidden states onto the vocabulary with the shared weights.

            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        leading_dims = shape_list(inputs)[:-1]
        flattened = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(flattened, self.weight, transpose_b=True)
        return tf.reshape(logits, leading_dims + [self.vocab_size])


thomwolf's avatar
thomwolf committed
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """

    def __init__(self, config, initializer_range=0.02, **kwargs):
        super(TFSequenceSummary, self).__init__(**kwargs)

        # Bug fix: the guard previously tested hasattr(config, "summary_use_proj")
        # while reading config.summary_type, so a config defining summary_use_proj
        # but not summary_type crashed, and one defining summary_type but not
        # summary_use_proj silently fell back to "last". Guard on the attribute
        # actually read.
        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Optional projection after the vector extraction.
        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        # Only 'tanh' is supported as a summary activation.
        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        # Dropout applied before the projection.
        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        # Dropout applied after the projection and activation.
        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        # Accept a bare tensor, a (hidden_states, cls_index) tuple/list, or a dict.
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                # Default to the last token of every sequence.
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output

577

thomwolf's avatar
thomwolf committed
578
579
580
def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly.

    Returns a list with one entry per dimension of `x`: the static size when
    it is known at graph-construction time, otherwise the corresponding
    element of the dynamic `tf.shape` tensor.
    """
    static_shape = x.shape.as_list()
    dynamic_shape = tf.shape(x)
    dims = []
    for axis, size in enumerate(static_shape):
        # None marks a dimension unknown until runtime; fall back to tf.shape.
        dims.append(dynamic_shape[axis] if size is None else size)
    return dims
thomwolf's avatar
thomwolf committed
583

584

thomwolf's avatar
thomwolf committed
585
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal kernel initializer.

    Args:
        initializer_range: float, standard deviation of the truncated normal.
    Returns:
        `tf.keras.initializers.TruncatedNormal` with stddev = `initializer_range`.
    """
    initializer = tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
    return initializer