# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""


import logging
import os

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model


logger = logging.getLogger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for `tf.keras.Model`s, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the model.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()
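
    # Illustrative usage (a hedged sketch, not executed here; `TFBertModel` is one
    # concrete subclass that ends up carrying this mixin):
    #
    #     model = TFBertModel.from_pretrained("bert-base-uncased")
    #     total = model.num_parameters()                         # all parameters
    #     trainable = model.num_parameters(only_trainable=True)  # trainable only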


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a TF 2.0 model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.TFPreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            Dict of str to tf.Tensor with dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A tf.keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError
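
    # A hedged sketch of the delegation above (illustrative names only): a derived
    # model sets `base_model_prefix = "transformer"` and stores its base network in
    # `self.transformer`; the base network then overrides this method, e.g.:
    #
    #     def get_input_embeddings(self):
    #         return self.embeddings  # the tf.keras.layers.Layer holding the token lookup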

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A tf.keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end.
            Reducing the size will remove vectors from the end.

        Args:
            old_embeddings: ``tf.Variable``
                Token embedding matrix to resize.
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end.
                Reducing the size will remove vectors from the end.
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # if new_num_tokens is None:
        #     return old_embeddings

        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings

        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)

        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)

        # # Copy token embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        # return new_embeddings
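
        # A hedged TF 2.0 sketch of the resizing logic mirrored from the commented-out
        # PyTorch code above (illustrative only; assumes `old_embeddings` is a 2-D
        # tf.Variable of shape [old_num_tokens, embedding_dim]):
        #
        #     old_num_tokens, embedding_dim = shape_list(old_embeddings)
        #     if new_num_tokens is None or new_num_tokens == old_num_tokens:
        #         return old_embeddings
        #     init = tf.random.normal((new_num_tokens, embedding_dim), stddev=0.02)
        #     num_to_copy = min(old_num_tokens, new_num_tokens)
        #     new_values = tf.concat([old_embeddings[:num_to_copy], init[num_to_copy:]], axis=0)
        #     return tf.Variable(new_values)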

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))
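
    # Illustrative save/load round trip (a sketch, assuming a concrete subclass
    # such as TFBertModel; not executed here):
    #
    #     model = TFBertModel.from_pretrained("bert-base-uncased")
    #     model.save_pretrained("./my_model_directory/")
    #     reloaded = TFBertModel.from_pretrained("./my_model_directory/")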

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = TFBertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = TFBertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = TFBertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error: no file named {} found in directory {} "
                        "(or `from_pt` is set to False while only a PyTorch checkpoint is present)".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                )
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                        )
                    )
                raise e
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(missing_keys) > 0:
            logger.info(
                "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model

    def prepare_inputs_for_generation(self, inputs, **kwargs):
        return {"inputs": inputs}

    def _do_output_past(self, outputs):
        """Whether `outputs` contains a `past`/`mems` state that can be reused to speed up decoding."""
        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len

        if has_output_past and not has_mem_len and len(outputs) > 1:
            return True
        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
            return True

        return False

    def generate(
        self,
        input_ids=None,
        max_length=None,
        do_sample=True,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_ids=None,
        length_penalty=None,
        num_return_sequences=None,
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
        and beam-search.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it with `bos_token_id` and a batch size of 1.

            max_length: (`optional`) int
                The max length of the sequence to be generated. Between 1 and infinity. Defaults to 20.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.

            top_p: (`optional`) float
                The cumulative probability threshold for nucleus (top-p) filtering: only the most probable vocabulary tokens whose cumulative probability reaches top_p are kept for sampling. Must be between 0 and 1. Defaults to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.

            bos_token_id: (`optional`) int
                Beginning of sentence token if no prompt is provided. Defaults to 0.

            pad_token_id: (`optional`) int
                Padding token used to fill finished sentences when sentences in the batch end at different lengths.

            eos_token_ids: (`optional`) int or list of int
                End of sequence token or list of tokens to stop the generation. Defaults to 0.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Defaults to 1.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Defaults to 1.

        Return:

            output: `tf.Tensor` of `dtype=tf.int32` and shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tf.constant(tokenizer.encode(input_context))[None, :]  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tf.constant(tokenizer.encode(input_context))[None, :]  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # generate 3 sequences by sampling
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tf.constant(tokenizer.encode(input_context))[None, :]  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )

        if input_ids is not None:
            batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
        else:
            batch_size = 1
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_ids is None) or (
            isinstance(eos_token_ids, (list, tuple)) and all(isinstance(e, int) and e >= 0 for e in eos_token_ids)
        ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = tf.fill((batch_size, 1), bos_token_id)
        else:
            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        if pad_token_id is None and eos_token_ids is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
            )
            pad_token_id = eos_token_ids[0]

        # current position and vocab size
        cur_len = shape_list(input_ids)[1]
        vocab_size = self.config.vocab_size

        if num_return_sequences != 1 and do_sample:
            # Expand input to num return sequences
            input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_return_sequences, cur_len))
            effective_batch_size = batch_size * num_return_sequences
            input_ids = tf.reshape(input_ids, (effective_batch_size, cur_len))
        else:
            effective_batch_size = batch_size

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len,
                max_length,
                do_sample,
                temperature,
                top_k,
                top_p,
                repetition_penalty,
                pad_token_id,
                eos_token_ids,
                effective_batch_size,
                num_return_sequences,
                length_penalty,
                num_beams,
                vocab_size,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len,
                max_length,
                do_sample,
                temperature,
                top_k,
                top_p,
                repetition_penalty,
                pad_token_id,
                eos_token_ids,
                effective_batch_size,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        pad_token_id,
        eos_token_ids,
        batch_size,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequence are generated independantly.
        """

        # length of generated sentences / unfinished sentences
        unfinished_sents = tf.ones_like(input_ids[:, 0])
        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

        past = None

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = tf.squeeze(
                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
                )
            else:
                # Greedy decoding
                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

            # update generations and finished sentences
            if eos_token_ids is not None:
                # pad finished sentences if eos_token_ids exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)

            if eos_token_ids is not None:
                for eos_token_id in eos_token_ids:
                    eos_in_sents = tokens_to_add == eos_token_id
                    # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                    is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                        unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                    )
                    sent_lengths = (
                        sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                        + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                    )

                    # unfinished_sents is set to zero if eos in sentence
                    unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if tf.math.reduce_max(unfinished_sents) == 0:
                break

        # if there are different sentence lengths in the batch, some sentences have to be padded
        min_sent_length = tf.math.reduce_min(sent_lengths)
        max_sent_length = tf.math.reduce_max(sent_lengths)
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if sentences have different lengths"
            # finished sents are filled with pad_token
            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id

            # create length masks for tf.where operation
            broad_casted_sent_lengths = tf.broadcast_to(
                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
            )
            broad_casted_range = tf.transpose(
                tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
            )

            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
        else:
            decoded = input_ids

        return decoded
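
    # Worked example of the padding step above (illustrative numbers): with
    # sent_lengths = [3, 5] and pad_token_id = 0, row 0 keeps its first 3 tokens
    # and the rest of the row is filled with 0, while row 1 is kept unchanged.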

    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        pad_token_id,
        eos_token_ids,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
    ):
        """ Generate sequences for each example with beam search.
        """

        # Expand input to num beams
        input_ids = tf.broadcast_to(tf.expand_dims(input_ids, 1), (batch_size, num_beams, cur_len))
        input_ids = tf.reshape(input_ids, (batch_size * num_beams, cur_len))  # (batch_size * num_beams, cur_len)

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
        ]

        # scores for each sentence in the beam: all beams but the first start at a very
        # low score so that the first step does not select num_beams copies of the same token
        beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
        beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
        beam_scores = tf.reshape(tf.concat([beam_scores_begin, beam_scores_end], -1), (batch_size * num_beams,))

        # cache compute states
        past = None

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = tf_top_k_top_p_filtering(
                    next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                next_tokens = tf.random.categorical(
                    next_token_logits, dtype=tf.int32, num_samples=2
                )  # (batch_size * num_beams, 2)
                # Compute next scores
                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
                next_scores = _scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, 2)
                )  # (batch_size * num_beams, 2)
                # Match shape of greedy beam search
                next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
                next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
            else:
                # do greedy beam search
                scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
                next_scores = scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
                )  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beams together (we keep the top hypotheses across beams)
                next_scores = tf.reshape(
                    next_scores, (batch_size, num_beams * vocab_size)
                )  # (batch_size, num_beams * vocab_size)
                next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)

            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

            # next batch beam content
            # list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    tf.reduce_max(next_scores[batch_idx]).numpy()
                )
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_ids is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next tokens for this sentence
                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):

                    # get beam and token IDs
                    beam_id = idx // vocab_size
                    token_id = idx % vocab_size

                    # add to generated hypotheses if end of sentence or last iteration
                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
                        generated_hyps[batch_idx].add(
                            tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
                        )
                    else:
                        # add next predicted token if it is not eos_token
                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)

            # re-order batch
            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)

            # re-order internal states
            if past:
                past = self._reorder_cache(past, beam_idx)

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        for batch_idx in range(batch_size):
            # Add all open beam hypothesis to generated_hyps
            if not done[batch_idx]:
                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):

                    # get beam and token IDs
                    beam_id = idx // vocab_size
                    token_id = idx % vocab_size
                    generated_hyps[batch_idx].add(
                        tf.identity(input_ids[batch_idx * num_beams + beam_id, :cur_len]), score.numpy()
                    )

        # depending on whether greedy generation is wanted or not, define different
        # output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths_list = []
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths_list.append(len(best_hyp))
                best.append(best_hyp)
        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
            output_batch_size, len(best)
        )

        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)

        # shorter batches are filled with pad_token
        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
            decoded_list = []

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
                decoded_hypo = tf.concat([hypo, padding], axis=0)

                if sent_lengths[i] < max_length:
                    decoded_hypo = tf.where(
                        tf.range(sent_max_len) == sent_lengths[i],
                        eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
                        decoded_hypo,
                    )
                decoded_list.append(decoded_hypo)
            decoded = tf.stack(decoded_list)
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = tf.stack(best)

        return decoded

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` and `mems` is at 2nd position
            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
            reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
            # check that shape matches
            assert shape_list(reordered_layer_past) == shape_list(layer_past)
            reordered_past.append(reordered_layer_past)
        past = tuple(reordered_past)
        return past
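
    # Illustrative effect of the reordering above: with 3 beams and
    # beam_idx = [2, 2, 0], each layer's cached states are re-gathered along their
    # beam/batch axis so that old beams 2, 2 and 0 become the new beams 0, 1 and 2.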


def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    # create logit penalties for already seen input_ids
    token_penalties = np.ones(shape_list(logits))
    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
    for i, prev_input_id in enumerate(prev_input_ids):
        logit_penalized = logits[i].numpy()[prev_input_id]
        # if previous logit score is < 0 then multiply repetition penalty else divide
        logit_penalized[logit_penalized < 0] = repetition_penalty
        logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
        np.put(token_penalties[i], prev_input_id, logit_penalized)
    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
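
# A worked example of the penalty above (illustrative, with repetition_penalty=1.2):
# a previously generated token with logit 2.0 is rescaled to 2.0 / 1.2 ≈ 1.67, and
# one with logit -2.0 to -2.0 * 1.2 = -2.4; both moves make re-sampling that token
# less likely, as described in the CTRL paper.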


def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    logits_shape = shape_list(logits)

    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)

    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
        sorted_logits = tf.gather(
            logits, sorted_indices, axis=-1, batch_dims=1
        )  # expects logits to be of dim (batch_size, vocab_size)

        cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 probability are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove = tf.concat(
                [
                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
                ],
                -1,
            )

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = tf.concat(
            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
        )
        # scatter sorted tensors to original indexing
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
    return logits
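# Illustrative usage sketch (commented out; `_example_logits` is made up): keep
# only the 2 highest-scoring tokens; all other logits are set to -inf so that a
# subsequent softmax assigns them zero probability mass.
#
#   _example_logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
#   tf_top_k_top_p_filtering(_example_logits, top_k=2)
#   # -> [[-inf, -inf, 3.0, 4.0]]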


def scatter_values_on_batch_indices(values, batch_indices):
    shape = shape_list(batch_indices)
    # broadcast batch dim to shape
    broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
    # transform batch_indices to pair_indices
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
    # scatter values to pair indices
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
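# Illustrative usage sketch (commented out; the `_example_*` names are made up):
# undo a per-row sort, mapping values given in sorted order back onto their
# original vocabulary positions.
#
#   _example_values = tf.constant([[True, False, False]])  # in sorted order
#   _example_indices = tf.constant([[2, 0, 1]])            # sorted_indices from tf.argsort
#   scatter_values_on_batch_indices(_example_values, _example_indices)
#   # -> [[False, False, True]]: the flag at sorted position 0 lands on index 2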


def set_tensor_by_indices_to_value(tensor, indices, value):
    # create value_tensor since tensor value assignment is not possible in TF
    value_tensor = tf.zeros_like(tensor) + value
    return tf.where(indices, value_tensor, tensor)
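# Illustrative usage sketch (commented out; the `_example_*` names are made up):
#
#   _example_tensor = tf.constant([1.0, 2.0, 3.0])
#   _example_mask = tf.constant([False, True, False])
#   set_tensor_by_indices_to_value(_example_tensor, _example_mask, 0.0)  # -> [1.0, 0.0, 3.0]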

class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
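# Illustrative usage sketch for BeamHypotheses (commented out; the scores are
# made up): the container keeps at most `num_beams` hypotheses and evicts the
# worst one, by length-penalized score, when a better candidate arrives.
#
#   _hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
#   _hyps.add(tf.constant([5, 7]), sum_logprobs=-1.0)   # score -0.50
#   _hyps.add(tf.constant([5, 8]), sum_logprobs=-4.0)   # score -2.00
#   _hyps.add(tf.constant([5, 9]), sum_logprobs=-0.5)   # score -0.25 -> evicts the -2.00 beam
#   len(_hyps)  # -> 2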


class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
        """
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x
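# Illustrative usage sketch (commented out; the shapes are made up): TFConv1D
# maps the last dimension from `nx` to `nf`, i.e. it acts like a Dense layer
# whose kernel is stored transposed.
#
#   _conv = TFConv1D(nf=8, nx=4, name="c_fc")
#   _conv(tf.zeros((2, 5, 4)))  # -> shape (2, 5, 8)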


class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.
    """

    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared token embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs.
        Args:
            inputs: in "embedding" mode, an int tensor of token indices with shape [batch_size, length];
                in "linear" mode, a float tensor with shape [..., hidden_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
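# Illustrative usage sketch (commented out; the sizes are made up): one weight
# matrix serves both as the input embedding and, transposed, as the output
# projection, which is how weight tying is implemented.
#
#   _emb = TFSharedEmbeddings(vocab_size=100, hidden_size=16, name="shared")
#   _hidden = _emb(tf.constant([[1, 2, 3]]), mode="embedding")  # (1, 3, 16)
#   _logits = _emb(_hidden, mode="linear")                      # (1, 3, 100)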


class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, anything else => no activation. Default: no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """

    def __init__(self, config, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output
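# Illustrative usage sketch (commented out; `_example_config` is hypothetical,
# with summary_type == "last" and the projection/dropout options disabled): the
# layer reduces (bsz, seq_len, hidden_size) to a (bsz, hidden_size) vector by
# taking the hidden state of the final token.
#
#   _summary = TFSequenceSummary(_example_config, name="sequence_summary")
#   _summary(tf.zeros((2, 7, 16)))  # -> shape (2, 16)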


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
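# Illustrative usage sketch (commented out): static dimensions come back as
# Python ints and unknown ones as dynamic scalar tensors, so the result can be
# passed straight to tf.reshape even for partially-known shapes.
#
#   _x = tf.zeros((2, 5, 16))
#   shape_list(_x)  # -> [2, 5, 16] in eager mode
#   # traced with an unknown batch size, the first entry would instead be a
#   # dynamic tensor taken from tf.shape(_x)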


def get_initializer(initializer_range=0.02):
    """Creates a `tf.initializers.truncated_normal` with the given range.
    Args:
        initializer_range: float, initializer range for stddev.
    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)