# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import logging
import os

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model


logger = logging.getLogger(__name__)


class TFModelUtilsMixin:
    """
    A few utilities for `tf.keras.Model`s, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the model.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()

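# Illustrative usage sketch (not executed here; `TFBertModel` is just an example
# model class):
#
#     model = TFBertModel.from_pretrained("bert-base-uncased")
#     model.num_parameters()                     # total parameter count
#     model.num_parameters(only_trainable=True)  # trainable parameters only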

def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:
    1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
       serialization time)
    2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer
    3. registering the class as a custom object in Keras (if the TensorFlow version supports this), so that it does
       not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`

    :param cls: a tf.keras.layers.Layer subclass that accepts a `config` argument to its initializer (typically a
                `TF*MainLayer` class in this project)
    :return: the same class object, with modifications for Keras deserialization.
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        transformers_config = kwargs.pop("transformers_config", None)
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
        if config is not None and transformers_config is not None:
            raise ValueError("Must pass either `config` or `transformers_config`, not both")
        elif config is not None:
            # normal layer construction, call with unchanged args (config is already in there)
            initializer(self, *args, **kwargs)
        elif transformers_config is not None:
            # Keras deserialization, convert dict to config
            config = config_class.from_dict(transformers_config)
            initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
        self._transformers_config = config

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["transformers_config"] = self._transformers_config.to_dict()
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
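# Illustrative sketch with hypothetical names (`MyConfig` and `TFMyMainLayer` are
# assumptions, not library classes): a main layer opting into Keras serialization.
#
#     @keras_serializable
#     class TFMyMainLayer(tf.keras.layers.Layer):
#         config_class = MyConfig
#
#         def __init__(self, config, **kwargs):
#             super().__init__(**kwargs)
#             self.hidden_size = config.hidden_size
#
# A saved Keras model containing such a layer can then be restored with
# `tf.keras.models.load_model` without passing `custom_objects`, provided the
# TensorFlow version supports `register_keras_serializable`.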


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention modules.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint into the model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.TFPreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            A dictionary of dummy inputs keyed by input name, with :obj:`tf.Tensor` values.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A Keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A Keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings
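    # Illustrative sketch (hypothetical attribute name): a subclass with a LM head
    # would typically override this method to expose its output layer, e.g.
    #
    #     def get_output_embeddings(self):
    #         return self.lm_head  # `lm_head` is an assumed attribute, for illustration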

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # if new_num_tokens is None:
        #     return old_embeddings

        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings

        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)

        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)

        # # Copy token embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        # return new_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). These arguments behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error: no file named {} found in directory {} (and `from_pt` is set to False)".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                )
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                        )
                    )
                raise e
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Compare the model layers with the checkpoint layers to report loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(missing_keys) > 0:
            logger.info(
                "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model

    def prepare_inputs_for_generation(self, inputs, **kwargs):
        return {"inputs": inputs}

    def _do_output_past(self, outputs):
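        # Heuristic: a truthy `config.output_past` (e.g. GPT-2/CTRL-style models) or a
        # positive `config.mem_len` (Transformer-XL-style models) signals that the
        # second model output holds cached states that can be fed back as `past`
        # to speed up decoding.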
        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len

        if has_output_past and not has_mem_len and len(outputs) > 1:
            return True
        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
            return True

        return False

    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bad_words_ids=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
        and beam-search.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as a `tf.Tensor` of shape `(1, 1)` filled with `bos_token_id`.

            max_length: (`optional`) int
                The maximum length of the sequence to be generated. Between 1 and infinity. Defaults to 20.

            min_length: (`optional`) int
                The minimum length of the sequence to be generated. Between 0 and infinity. Defaults to 0.
            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                If set to `True` beam search is stopped when at least `num_beams` sentences are finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.

            top_p: (`optional`) float
                The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.

            bos_token_id: (`optional`) int
                Beginning of sentence token if no prompt is provided. Defaults to the model-specific `bos_token_id`, or `None` if it does not exist.

            pad_token_id: (`optional`) int
                Pad token. Defaults to `pad_token_id` as defined in the model's config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to `eos_token_id` as defined in the model's config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Defaults to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.

            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Defaults to 1.

            attention_mask (`optional`) obj: `tf.Tensor` with `dtype=tf.int32` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

        Return:

            output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, num_return_sequences=3)  # generate 3 independent sequences by sampling
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
601
                "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
602
603
604
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a non-negative integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a non-negative integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a non-negative integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a non-negative integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = tf.fill((batch_size, 1), bos_token_id)
        else:
            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # do not allow duplicate outputs when doing greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
            attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
        elif attention_mask is None:
            attention_mask = tf.ones_like(input_ids)
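        # Illustrative example: with pad_token_id = 0 and input_ids = [[5, 6, 0]],
        # the first branch derives attention_mask = [[1, 1, 0]]; otherwise the
        # mask is all ones.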

        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        cur_len = shape_list(input_ids)[1]
        vocab_size = self.config.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1
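        # Illustrative example: batch_size=2, num_return_sequences=3 and
        # do_sample=True give effective_batch_size=6 and effective_batch_mult=3,
        # i.e. each prompt is duplicated three times and sampled independently.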

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = shape_list(input_ids)[-1]
            input_ids = tf.broadcast_to(
                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
            )
            attention_mask = tf.broadcast_to(
                tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
            )
            input_ids = tf.reshape(
                input_ids, (effective_batch_size * num_beams, input_ids_len)
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = tf.reshape(
                attention_mask, (effective_batch_size * num_beams, input_ids_len)
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
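            # Illustrative shapes: batch_size=2, num_beams=3, num_return_sequences=1
            # and cur_len=4 give input_ids of shape (6, 4); row i belongs to
            # prompt i // num_beams.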

        if self.config.is_encoder_decoder:

            # create empty decoder_input_ids
            input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
            cur_len = 1

            assert (
                batch_size == encoder_outputs[0].shape[0]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = tf.reshape(
                tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
                shape=(-1,),
            )
            # expand encoder_outputs
            encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:])
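            # Illustrative example: batch_size=2, num_beams=3, effective_batch_mult=1
            # give expanded_batch_idxs = [0, 0, 0, 1, 1, 1], so every beam row is
            # paired with its prompt's encoder output.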

        else:
            encoder_outputs = None
            cur_len = shape_list(input_ids)[-1]

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        vocab_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """

        # length of generated sentences / unfinished sentences
        unfinished_sents = tf.ones_like(input_ids[:, 0])
        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                # create banned_tokens boolean mask
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [token in banned_tokens_slice for token in range(vocab_size)]
                    )

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )
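                # Illustrative example: vocab_size=5 and banned_tokens=[[2, 4]] yield
                # the mask [[False, False, True, False, True]]; masked logits are set
                # to -inf so those tokens can never be sampled.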

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [token in banned_tokens_slice for token in range(vocab_size)]
                    )

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                # create eos_token_id boolean mask
                is_token_logit_eos_token = tf.convert_to_tensor(
                    [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])

                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, eos_token_indices_mask, -float("inf")
                )

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = tf.squeeze(
                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
                )
            else:
                # Greedy decoding
                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                )
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )

                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if tf.math.reduce_max(unfinished_sents) == 0:
                break

            # extend attention_mask for the newly generated input if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = tf.concat(
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
                )

            cur_len = cur_len + 1

        # if there are different sentence lengths in the batch, some sequences have to be padded
        min_sent_length = tf.math.reduce_min(sent_lengths)
        max_sent_length = tf.math.reduce_max(sent_lengths)
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id

            # create length masks for tf.where operation
            broad_casted_sent_lengths = tf.broadcast_to(
                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
            )
            broad_casted_range = tf.transpose(
                tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
            )

            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
        else:
            decoded = input_ids

        return decoded

    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        decoder_start_token_id,
        eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens num_beams times
        if do_sample is False:
            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
            beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
        else:
            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)

        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))

        # cache compute states
        past = encoder_outputs

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # calculate log softmax score
            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                # create eos_token_id boolean mask
                num_batch_hypotheses = batch_size * num_beams

                is_token_logit_eos_token = tf.convert_to_tensor(
                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])

                scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                num_batch_hypotheses = batch_size * num_beams
                banned_tokens = calc_banned_ngram_tokens(
                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
                )
                # create banned_tokens boolean mask
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )

                scores = set_tensor_by_indices_to_value(
                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )

                scores = set_tensor_by_indices_to_value(
                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            assert shape_list(scores) == [batch_size * num_beams, vocab_size]

            if do_sample:
                _scores = scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
                )  # (batch_size * num_beams, vocab_size)

                # Top-p/top-k filtering
                _scores = tf_top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))

                next_tokens = tf.random.categorical(
                    _scores, dtype=tf.int32, num_samples=2 * num_beams
                )  # (batch_size, 2 * num_beams)
                # Compute next scores
                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)

                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
            else:
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
                next_scores = scores + tf.broadcast_to(
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
                )  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping the top hypotheses across beams)
                next_scores = tf.reshape(
                    next_scores, (batch_size, num_beams * vocab_size)
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)

            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence or last iteration
                    if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
                        )
                    else:
                        # add next predicted token if it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we're done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)

            # re-order batch
            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = tf.concat(
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
                )

            # update current length
            cur_len = cur_len + 1

        # finalize all open beam hypotheses and add them to generated hypotheses
        for batch_idx in range(batch_size):
            # Add all open beam hypotheses to generated_hyps
            if done[batch_idx]:
                continue
            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert tf.reduce_all(
                    next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].numpy().item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths_list = []
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths_list.append(len(best_hyp))
                best.append(best_hyp)
        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
            output_batch_size, len(best)
        )

        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)

        # shorter batches are filled with pad_token
        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
            decoded_list = []

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                assert sent_lengths[i] == shape_list(hypo)[0]
                # if sent_length is sent_max_len do not pad
                if sent_lengths[i] == sent_max_len:
                    decoded_slice = hypo
                else:
                    # else pad to sent_max_len
                    num_pad_tokens = sent_max_len - sent_lengths[i]
                    padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
                    decoded_slice = tf.concat([hypo, padding], axis=-1)

                    # finish sentence with EOS token
                    if sent_lengths[i] < max_length:
                        decoded_slice = tf.where(
                            tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
                            eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
                            decoded_slice,
                        )
                # add to list
                decoded_list.append(decoded_slice)

            decoded = tf.stack(decoded_list)
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = tf.stack(best)

        return decoded

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` and `mems` is at 2nd position
            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
            reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
            # check that shape matches
            assert shape_list(reordered_layer_past) == shape_list(layer_past)
            reordered_past.append(reordered_layer_past)
        past = tuple(reordered_past)
        return past


def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
    # create logit penalties for already seen input_ids
    token_penalties = np.ones(shape_list(logits))
    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
    for i, prev_input_id in enumerate(prev_input_ids):
        logit_penalized = logits[i].numpy()[prev_input_id]
        logit_penalties = np.zeros(logit_penalized.shape)
        # if previous logit score is < 0 then multiply by the repetition penalty, else divide
        logit_penalties[logit_penalized < 0] = repetition_penalty
        logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
        np.put(token_penalties[i], prev_input_id, logit_penalties)
    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
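

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): applying the CTRL-style repetition penalty
# to a toy batch where tokens 0 and 2 have already been generated.
def _demo_repetition_penalty():
    input_ids = tf.constant([[0, 2, 2]])
    logits = tf.constant([[1.0, 1.0, 1.0, -1.0]])
    penalties = _create_next_token_logits_penalties(input_ids, logits, 2.0)
    # logits of already-seen tokens are divided by 2.0 if positive, multiplied if negative
    return tf.math.multiply(logits, penalties)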


def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].numpy().tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
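

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): with the prefix "1 2 3 1 2" and
# no_repeat_ngram_size=3, the bigram (1, 2) was already followed by 3, so
# token 3 is banned for the next step.
def _demo_calc_banned_ngram_tokens():
    prev_input_ids = tf.constant([[1, 2, 3, 1, 2]])
    return calc_banned_ngram_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=3, cur_len=5)  # [[3]]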


def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if the bad word sequence is just one token, always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev tokens they can't be equal
            return False

        if prev_tokens[-len(tokens) :] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot contain an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
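

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): a single-token bad word is always banned,
# while a two-token bad word is only banned once its prefix has been generated.
def _demo_calc_banned_bad_words_ids():
    prev_input_ids = tf.constant([[5, 6]])
    bad_words_ids = [[7], [6, 8]]
    return calc_banned_bad_words_ids(prev_input_ids, bad_words_ids)  # [[7, 8]]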


def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    logits_shape = shape_list(logits)

    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)

    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
        sorted_logits = tf.gather(
            logits, sorted_indices, axis=-1, batch_dims=1
        )  # expects logits to be of dim (batch_size, vocab_size)

        cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p

        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove = tf.concat(
                [
                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
                ],
                -1,
            )

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
        sorted_indices_to_remove = tf.concat(
            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
        )
        # scatter sorted tensors to original indexing
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
    return logits
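

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): with top_k=2 only the two highest logits
# survive, everything else is set to the filter value of -inf.
def _demo_tf_top_k_top_p_filtering():
    logits = tf.constant([[6.0, 5.0, 4.0, 1.0]])
    return tf_top_k_top_p_filtering(logits, top_k=2)  # [[6.0, 5.0, -inf, -inf]]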


def scatter_values_on_batch_indices(values, batch_indices):
    shape = shape_list(batch_indices)
    # broadcast batch dim to shape
    broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
    # transform batch_indices to pair_indices
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
    # scatter values to pair indices
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
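

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): values arranged per batch row are scattered
# back to the column positions given by batch_indices.
def _demo_scatter_values_on_batch_indices():
    values = tf.constant([[10, 20], [30, 40]])
    batch_indices = tf.constant([[1, 0], [0, 1]])
    return scatter_values_on_batch_indices(values, batch_indices)  # [[20, 10], [30, 40]]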


def set_tensor_by_indices_to_value(tensor, indices, value):
    # create value_tensor since tensor value assignment is not possible in TF
    value_tensor = tf.zeros_like(tensor) + value
    return tf.where(indices, value_tensor, tensor)
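

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): emulates in-place masked assignment, which
# plain TF tensors do not support.
def _demo_set_tensor_by_indices_to_value():
    tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    mask = tf.constant([[True, False], [False, True]])
    return set_tensor_by_indices_to_value(tensor, mask, 0.0)  # [[0.0, 2.0], [3.0, 0.0]]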


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
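

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): scores are length-normalized on `add`, and
# without early stopping `is_done` compares the worst kept score against the
# best score still achievable at the current length.
def _demo_beam_hypotheses():
    hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
    hyps.add(tf.constant([0, 1, 2]), sum_logprobs=-1.5)  # score -0.5
    hyps.add(tf.constant([0, 3]), sum_logprobs=-0.4)  # score -0.2
    return hyps.is_done(best_sum_logprobs=-3.0, cur_len=5)  # True: -3.0 / 5 is below the worst kept score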


class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
        """
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])
        return x
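

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): TFConv1D acts as a dense projection from
# nx to nf features over the last axis.
def _demo_tf_conv1d():
    layer = TFConv1D(nf=8, nx=4)
    x = tf.random.uniform((2, 5, 4))  # (batch, seq_len, nx)
    return layer(x)  # (2, 5, 8)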


class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.
    """

    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared token embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs.
        Args:
            inputs: in "embedding" mode, an int64 tensor with shape [batch_size, length];
                in "linear" mode, a float32 tensor with shape [..., hidden_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
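

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): the same weight matrix is used to embed ids
# and, transposed, to project hidden states back to vocabulary logits.
def _demo_tf_shared_embeddings():
    emb = TFSharedEmbeddings(vocab_size=100, hidden_size=16)
    hidden = emb(tf.constant([[1, 2, 3]]), mode="embedding")  # (1, 3, 16)
    logits = emb(hidden, mode="linear")  # (1, 3, 100)
    return logits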


class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation (default)
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """

    def __init__(self, config, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(
                    hidden_shape[:-2], hidden_shape[-2] - 1
                )  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            # cls_index = cls_index[..., tf.newaxis]
            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output
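

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name and the `_Config` stand-in are hypothetical): summarizes a
# sequence by projecting the last token's hidden state through a tanh layer.
def _demo_tf_sequence_summary():
    class _Config:
        summary_type = "last"
        summary_use_proj = True
        summary_proj_to_labels = False
        summary_activation = "tanh"
        summary_first_dropout = 0.0
        summary_last_dropout = 0.0
        hidden_size = 16
        num_labels = 0

    layer = TFSequenceSummary(_Config())
    hidden_states = tf.random.uniform((2, 7, 16))
    return layer(hidden_states)  # (2, 16)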


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
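

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): static dimensions come back as python ints,
# while dimensions unknown at graph-build time come back as tensors.
def _demo_shape_list():
    return shape_list(tf.zeros((2, 3)))  # [2, 3]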


def get_initializer(initializer_range=0.02):
    """Creates a `tf.initializers.truncated_normal` with the given range.
    Args:
        initializer_range: float, initializer range for stddev.
    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
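

# Minimal usage sketch (illustrative only, not part of the library API; the
# `_demo_*` name is hypothetical): draws a weight matrix from the truncated
# normal distribution used to initialize the layers above.
def _demo_get_initializer():
    init = get_initializer(initializer_range=0.02)
    return init(shape=(4, 4))  # values clipped to roughly +/- 2 stddev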