# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""


import logging
import os

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model


logger = logging.getLogger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for `tf.keras.Model`s, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the model.
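
        Illustrative usage (a sketch; assumes a concrete subclass such as ``TFBertModel``)::

            # For example purposes. Not runnable.
            model = TFBertModel.from_pretrained('bert-base-uncased')
            total_params = model.num_parameters()
            trainable_params = model.num_parameters(only_trainable=True)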
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            dict of ``tf.Tensor`` with dummy inputs, keyed by input name (``input_ids``)
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A TF 2.0 Keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A TF 2.0 Keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # if new_num_tokens is None:
        #     return old_embeddings

        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings

        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)

        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)

        # # Copy word embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        # return new_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
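
            Illustrative usage (a sketch; assumes a concrete subclass such as ``TFBertModel``)::

                # For example purposes. Not runnable.
                model = TFBertModel.from_pretrained('bert-base-uncased')
                model.save_pretrained('./my_model_directory/')  # the directory must already exist
                model = TFBertModel.from_pretrained('./my_model_directory/')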
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The model is in inference mode by default: dropout modules are deactivated because the network is built and restored with ``training=False``.
        To train the model, call it with ``training=True`` (TF 2.0 models have no ``model.eval()``/``model.train()`` switch).

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g. ``output_attentions=True``). These arguments behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = TFBertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFBertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = TFBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                )
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                        )
                    )
                raise e
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(missing_keys) > 0:
            logger.info(
                "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model


class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
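
            Illustrative shape contract (a sketch; the sizes below are arbitrary)::

                # For example purposes. Not runnable.
                layer = TFConv1D(nf=3 * 768, nx=768)
                hidden = tf.ones([2, 5, 768])  # [batch, seq_len, nx]
                output = layer(hidden)         # [batch, seq_len, nf], i.e. [2, 5, 2304]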
        """
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x


class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.
    """

    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared word embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs.
        Args:
            inputs: an int64 tensor with shape [batch_size, length] when mode == "embedding", or a float32 tensor with shape [..., hidden_size] when mode == "linear".
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
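
        Illustrative usage (a sketch; the sizes below are arbitrary)::

            # For example purposes. Not runnable.
            embeddings = TFSharedEmbeddings(vocab_size=10, hidden_size=4)
            input_ids = tf.constant([[1, 2, 3]])               # [batch_size, length]
            hidden_states = embeddings(input_ids)              # mode="embedding" -> [1, 3, 4]
            logits = embeddings(hidden_states, mode="linear")  # tied projection -> [1, 3, 10]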
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, other => no activation (default).
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
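
        Illustrative usage (a sketch; assumes ``config`` defines the ``summary_*`` attributes above, e.g. a GPT-2 config)::

            # For example purposes. Not runnable.
            summary = TFSequenceSummary(config, initializer_range=config.initializer_range)
            # hidden_states: [batch_size, seq_len, hidden_size], cls_index: [batch_size]
            summary_vector = summary([hidden_states, cls_index])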
    """

    def __init__(self, config, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
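
# Illustrative use of ``shape_list`` (a sketch, kept as a comment so nothing runs at import time):
# within a Keras graph, ``x.shape`` may contain ``None`` for dynamic axes while ``tf.shape(x)`` is a
# tensor; ``shape_list`` mixes the two so the result can be passed to ``tf.reshape``:
#
#     batch_size, seq_len, hidden_size = shape_list(hidden_states)
#     flat = tf.reshape(hidden_states, [batch_size * seq_len, hidden_size])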


def get_initializer(initializer_range=0.02):
    """Creates a `tf.initializers.truncated_normal` with the given range.
    Args:
        initializer_range: float, initializer range for stddev.
    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)