# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch model base classes and utilities."""

import logging
import os
import typing

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .file_utils import (
    DUMMY_INPUTS,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_NAME,
    cached_path,
    hf_bucket_url,
    is_remote_url,
)


logger = logging.getLogger(__name__)


try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, input):
            return input


class ModuleUtilsMixin:
    """
    A few utilities for torch.nn.Modules, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the module.
        """
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)
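
    # Illustrative usage sketch (not executed here); the checkpoint name is only an example:
    #
    #     model = BertModel.from_pretrained("bert-base-uncased")
    #     total = model.num_parameters()                         # all parameters
    #     trainable = model.num_parameters(only_trainable=True)  # only parameters with requires_grad=True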

    @staticmethod
    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_pre_forward = mem.rss
        return None

    @staticmethod
    def _hook_rss_memory_post_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_post_forward = mem.rss
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
        return None

    def add_memory_hooks(self):
        """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
            Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
        """
        for module in self.modules():
            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
            module.register_forward_hook(self._hook_rss_memory_post_forward)
        self.reset_memory_hooks_state()

    def reset_memory_hooks_state(self):
        for module in self.modules():
            module.mem_rss_diff = 0
            module.mem_rss_post_forward = 0
            module.mem_rss_pre_forward = 0
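
    # Illustrative usage sketch for the memory hooks above (requires `psutil`); `model` and
    # `input_ids` are placeholders for an already-instantiated model and its inputs:
    #
    #     model.add_memory_hooks()
    #     outputs = model(input_ids)
    #     print(model.mem_rss_diff)         # accumulated RSS increase (bytes) recorded on the top module
    #     model.reset_memory_hooks_state()  # zero the per-module counters again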

    @property
    def device(self):
        return next(self.parameters()).device


class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    r""" Base class for all models.

        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated with the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to do a forward pass in the network.

        Returns:
            torch.Tensor with dummy inputs
        """
        return {"input_ids": torch.tensor(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        Args:
            value (:obj:`nn.Module`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        If the `torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing so we clone
        the weights instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds
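
    # Illustrative usage sketch (assumes a matching `tokenizer`/`model` pair already loaded with
    # `from_pretrained`): grow the embedding matrix after adding new tokens to the tokenizer.
    #
    #     tokenizer.add_tokens(["<new_token>"])
    #     model.resize_token_embeddings(len(tokenizer))  # newly added rows are freshly initialized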

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embeddings``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def init_weights(self):
        """ Initializes and prunes weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)
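
    # Illustrative usage sketch, using the dict format described in the docstring above:
    #
    #     model.prune_heads({1: [0, 2], 2: [2, 3]})  # prune heads 0 and 2 of layer 1, heads 2 and 3 of layer 2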

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.

            Arguments:
                save_directory: directory to which to save.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model itself if we are using distributed training
        model_to_save = self.module if hasattr(self, "module") else self

        # Attach architecture to the config
        model_to_save.config.architectures = [model_to_save.__class__.__name__]

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

        if hasattr(self.config, "xla_device") and self.config.xla_device:
            import torch_xla.core.xla_model as xm

            if xm.is_master_ordinal():
                # Save configuration file
                model_to_save.config.save_pretrained(save_directory)
            # xm.save takes care of saving only from master
            xm.save(model_to_save.state_dict(), output_model_file)
        else:
            model_to_save.config.save_pretrained(save_directory)
            torch.save(model_to_save.state_dict(), output_model_file)

        logger.info("Model weights saved in {}".format(output_model_file))
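
    # Illustrative save/reload round trip (the directory name is only an example):
    #
    #     model.save_pretrained("./my_model_directory/")  # writes the configuration and the weights file
    #     reloaded = BertModel.from_pretrained("./my_model_directory/")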

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
              - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
              - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
              - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
              - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
              - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                    - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                    - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
                else:
                    msg = (
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url to model weight files named one of {} but "
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
                        )
                    )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if "gamma" in key:
                    new_key = key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = key.replace("beta", "bias")
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module: nn.Module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ""
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)

            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
                ]

                missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
        model.tie_weights()  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        if hasattr(config, "xla_device") and config.xla_device:
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model = model.to(xm.xla_device())

        return model
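
    # Illustrative sketch of `output_loading_info` (the checkpoint name is only an example):
    #
    #     model, loading_info = BertForSequenceClassification.from_pretrained(
    #         "bert-base-uncased", output_loading_info=True
    #     )
    #     print(loading_info["missing_keys"])     # e.g. the newly initialized classification head weights
    #     print(loading_info["unexpected_keys"])  # checkpoint weights not used by this architecture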

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def prepare_scores_for_generation(self, scores, **kwargs):
        return scores

    def _do_output_past(self, outputs):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
        has_output_past = getattr(self.config, "output_past", False)
        mem_len = getattr(self.config, "mem_len", 0)
        if len(outputs) <= 1:
            return False
        if mem_len > 0 or has_output_past:
            return True
        return False

    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty
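
    # Worked example with illustrative numbers: with repetition_penalty = 1.2, a previously generated
    # token with score -2.0 becomes -2.0 * 1.2 = -2.4, while one with score 3.0 becomes 3.0 / 1.2 = 2.5;
    # in both cases the token becomes less likely to be generated again.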

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bad_words_ids=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape `(1,)`.

            max_length: (`optional`) int
                The max length of the sequence to be generated.  Between `min_length` and infinity. Default to 20.

            min_length: (`optional`) int
                The min length of the sequence to be generated.  Between 0 and infinity. Default to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.

            top_p: (`optional`) float
                The cumulative probability threshold for nucleus sampling: only the highest-probability vocabulary tokens whose cumulative probability does not exceed `top_p` are kept. Must be between 0 and 1. Default to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

            pad_token_id: (`optional`) int
                Padding token. Default to specific model pad_token_id or None if it does not exist.

            bos_token_id: (`optional`) int
                BOS token. Defaults to `bos_token_id` as defined in the models config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to `eos_token_id` as defined in the models config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Default to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Default to 1.

            attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

            `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # do not allow duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        vocab_size = self.config.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1
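
        # Worked example (illustrative numbers): with batch_size = 2 and num_return_sequences = 3,
        # sampling (do_sample=True) gives effective_batch_size = 6 and effective_batch_mult = 3,
        # while greedy/beam search keeps effective_batch_size = 2 and effective_batch_mult = 1.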

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

945
946
947
948
        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
949
950
951
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )
Patrick von Platen's avatar
Patrick von Platen committed
952

953
954
955
            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
Patrick von Platen's avatar
Patrick von Platen committed
956
957
958
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
959

Patrick von Platen's avatar
Patrick von Platen committed
960
        if self.config.is_encoder_decoder:
961
            # create empty decoder_input_ids
Patrick von Platen's avatar
Patrick von Platen committed
962
963
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
964
                decoder_start_token_id,
Patrick von Platen's avatar
Patrick von Platen committed
965
966
967
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
968
            cur_len = 1
969

970
            assert (
971
972
973
974
975
                batch_size == encoder_outputs[0].shape[0]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = (
976
977
978
979
980
981
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )
982
983
            # expand encoder_outputs
            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
984

Patrick von Platen's avatar
Patrick von Platen committed
985
        else:
986
            encoder_outputs = None
Patrick von Platen's avatar
Patrick von Platen committed
987
988
            cur_len = input_ids.shape[-1]

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )

        return output
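
    # Usage sketch (illustrative only; `model` and `tokenizer` below are assumed to be a
    # compatible pretrained model/tokenizer pair, they are not defined in this module):
    #
    #     input_ids = tokenizer.encode("The weather is", return_tensors="pt")
    #     output_ids = model.generate(input_ids, max_length=20, do_sample=True, top_k=50, top_p=0.95)
    #     text = tokenizer.decode(output_ids[0], skip_special_tokens=True)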

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)

            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                next_token_logits[:, eos_token_id] = -float("inf")

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            cur_len = cur_len + 1

        # if there are different sentence lengths in the batch, some batches have to be padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
        else:
            decoded = input_ids

        for hypo_idx, hypo in enumerate(input_ids):
            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

        return decoded
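
    # Worked example (illustrative): with unfinished_sents = [1, 0], next_token = [42, 17] and
    # pad_token_id = 0, the update above gives tokens_to_add = [42, 0]: unfinished sentences get
    # the sampled/argmax token while finished ones keep receiving the pad token.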

    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(
                    next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
                )

            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                scores[:, eos_token_id] = -float("inf")

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                num_batch_hypotheses = batch_size * num_beams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_batch_tokens = calc_banned_ngram_tokens(
                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
                )
                for i, banned_tokens in enumerate(banned_batch_tokens):
                    scores[i, banned_tokens] = -float("inf")

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                for i, banned_tokens in enumerate(banned_tokens):
                    scores[i, banned_tokens] = -float("inf")

            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size)
            )

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypothesis across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence or last iteration
                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted token if it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we're done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len=cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            # update current length
            cur_len = cur_len + 1

        # finalize all open beam hypotheses and add them to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are filled with pad_token
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert (len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded
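
    # Worked example (illustrative): with vocab_size = 100 and num_beams = 3, a flattened index
    # beam_token_id = 205 returned by the topk/multinomial over (num_beams * vocab_size) entries
    # decomposes into beam_id = 205 // 100 = 2 and token_id = 205 % 100 = 5, i.e. token 5
    # continuing beam 2 of that batch element.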

    # force one of token_ids to be generated by setting prob of all other tokens to 0.
    def _force_token_ids_generation(self, scores, token_ids):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        all_but_token_ids_mask = torch.tensor(
            [x for x in range(self.config.vocab_size) if x not in token_ids],
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        scores[:, all_but_token_ids_mask] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
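
    # Note (an assumption of this default implementation): the cached key/value tensors keep the
    # beam/batch dimension at index 1, so `index_select(1, beam_idx)` realigns the cache with the
    # re-ordered beams after each beam search step; a model with a different cache layout would
    # need to override `_reorder_cache`.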


def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
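
# Worked example (illustrative): with no_repeat_ngram_size = 2 and a hypothesis [5, 7, 5] at
# cur_len = 3, the recorded bigrams are {(5,): [7], (7,): [5]}; since the last generated token
# is 5, token 7 is returned as banned for the next step, so the bigram (5, 7) cannot repeat.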


def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than the generated tokens they can't be equal
            return False

        if prev_tokens[-len(tokens) :] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
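
# Worked example (illustrative): with bad_words_ids = [[3, 4]], a hypothesis whose generated
# prefix ends in token 3 gets token 4 banned for the next step, so the sequence [3, 4] can never
# be produced; a single-token entry such as [[9]] is banned at every step.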


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
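

# Usage sketch (illustrative, not part of the public API): filter a batch of logits, then sample:
#
#     logits = torch.randn(2, 50)                                   # (batch_size, vocab_size)
#     filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9)
#     next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
#
# Note that the filtering writes into the tensor it is given, hence the `.clone()` above.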


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
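
    # Worked example (illustrative): a hypothesis of 4 tokens with sum of log-probabilities -2.0
    # and length_penalty = 1.0 is stored with score -2.0 / 4 ** 1.0 = -0.5; a larger
    # length_penalty divides by a larger factor and therefore favours longer hypotheses.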


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
        """
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
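

# Usage sketch (illustrative): Conv1D(nf, nx) maps the last dimension from nx to nf features,
# behaving like nn.Linear(nx, nf) with a transposed weight layout:
#
#     layer = Conv1D(nf=12, nx=4)
#     out = layer(torch.randn(2, 7, 4))  # -> shape (2, 7, 12)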


class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x
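
    # Note: the masking above pushes masked positions to a very large negative logit
    # (-65500 in fp16, which stays within half-precision range; -1e30 otherwise) so that they
    # receive effectively zero probability after a subsequent softmax.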


class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample
        """
        hsz = hidden_states.shape[-1]
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Log probabilities for the ``is_impossible`` label of the answers.
    """

    def __init__(self, config):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(
        self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
    ):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
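
    # Usage note (illustrative): when called with `start_positions` and `end_positions` (and
    # optionally `cls_index`/`is_impossible`) the head returns `(total_loss,)`; when called with
    # only `hidden_states` (and optionally `p_mask`) it returns the beam-search style top
    # start/end candidates together with `cls_logits`.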


class SequenceSummary(nn.Module):
thomwolf's avatar
thomwolf committed
1840
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
thomwolf's avatar
thomwolf committed
1841
1842
1843
1844
1845
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
thomwolf's avatar
thomwolf committed
1846
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
thomwolf's avatar
thomwolf committed
1847
1848
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
1849
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1850
            summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
1851
1852
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
thomwolf's avatar
thomwolf committed
1853
    """
1854

1855
    def __init__(self, config: PretrainedConfig):
Julien Chaumond's avatar
Julien Chaumond committed
1856
        super().__init__()
thomwolf's avatar
thomwolf committed
1857

1858
        self.summary_type = getattr(config, "summary_type", "last")
1859
        if self.summary_type == "attn":
thomwolf's avatar
thomwolf committed
1860
1861
1862
1863
1864
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation = (
            get_activation(activation_string) if activation_string else Identity()
        )  # type: typing.Callable

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
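
# Illustrative usage sketch (not part of the library): SequenceSummary only reads the
# summary_* attributes off the config object it is given, so any object exposing them
# works here; the attribute values below are hypothetical.
#
#     class _SummaryConfig:  # minimal stand-in for a model config
#         summary_type = "mean"
#         summary_use_proj = True
#         summary_proj_to_labels = False
#         summary_activation = "tanh"
#         summary_first_dropout = 0.1
#         summary_last_dropout = 0.1
#         hidden_size = 768
#         num_labels = 0
#
#     summary = SequenceSummary(_SummaryConfig())
#     hidden_states = torch.rand(2, 10, 768)  # (bsz, seq_len, hidden_size)
#     pooled = summary(hidden_states)         # shape (2, 768): mean over the seq_len axis,
#                                             # then projection, tanh and dropout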


def create_position_ids_from_input_ids(input_ids, padding_idx):
    """ Replace non-padding symbols with their position numbers. Position numbers begin at
    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
    `utils.make_positions`.

    :param torch.Tensor input_ids: token ids, with padding positions equal to padding_idx
    :param int padding_idx: id of the padding token
    :return torch.Tensor: position ids with the same shape as input_ids
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx
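
# Worked example (illustrative, not part of the library): with padding_idx = 1 and
#     input_ids = torch.tensor([[5, 7, 9, 1, 1]])
# the non-padding mask is [[1, 1, 1, 0, 0]], its cumulative sum is [[1, 2, 3, 3, 3]], and
# re-applying the mask plus the padding_idx offset gives position ids [[2, 3, 4, 1, 1]]:
#
#     create_position_ids_from_input_ids(torch.tensor([[5, 7, 9, 1, 1]]), padding_idx=1)
#     # -> tensor([[2, 3, 4, 1, 1]])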


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
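
# Illustrative usage sketch (not part of the library): keep only the first and third output
# units of a small nn.Linear by pruning along dim=0 (the default, i.e. the output dimension);
# the layer sizes below are hypothetical.
#
#     layer = nn.Linear(4, 3)                                   # weight (3, 4), bias (3,)
#     pruned = prune_linear_layer(layer, torch.tensor([0, 2]))  # dim=0 by default
#     # pruned.weight.shape == (2, 4) and pruned.bias.shape == (2,)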


def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer


def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
    elif isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))
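
# Illustrative usage sketch (not part of the library): prune_layer dispatches on the layer
# type, defaulting to dim=0 for nn.Linear and dim=1 for Conv1D, so the same call covers both
# layer types used by the models in this library; the layer sizes below are hypothetical.
#
#     kept = torch.tensor([0, 2])
#     pruned_linear = prune_layer(nn.Linear(4, 3), kept)  # weight (3, 4) -> (2, 4)
#     pruned_conv1d = prune_layer(Conv1D(3, 4), kept)     # weight (4, 3) -> (4, 2)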