"vscode:/vscode.git/clone" did not exist on "5090ea3f68de825e5a51d0cb3c0fa75c12af40b0"
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import logging
import os
import typing

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .file_utils import (
    DUMMY_INPUTS,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_NAME,
    cached_path,
    hf_bucket_url,
    is_remote_url,
)


logger = logging.getLogger(__name__)


try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, input):
            return input


class ModuleUtilsMixin:
    """
    A few utilities for torch.nn.Modules, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the module.
        """
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)
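    # Illustrative usage sketch (not part of the original source); `BertModel` and the
    # checkpoint name below are assumptions used only for illustration.
    #
    #     from transformers import BertModel
    #     model = BertModel.from_pretrained("bert-base-uncased")
    #     print(model.num_parameters())                     # all parameters
    #     print(model.num_parameters(only_trainable=True))  # trainable parameters only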

    @staticmethod
    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_pre_forward = mem.rss
        return None

    @staticmethod
    def _hook_rss_memory_post_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_post_forward = mem.rss
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
        return None

    def add_memory_hooks(self):
        """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
            Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
        """
        for module in self.modules():
            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
            module.register_forward_hook(self._hook_rss_memory_post_forward)
        self.reset_memory_hooks_state()

    def reset_memory_hooks_state(self):
        for module in self.modules():
            module.mem_rss_diff = 0
            module.mem_rss_post_forward = 0
            module.mem_rss_pre_forward = 0
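    # Illustrative usage sketch (not part of the original source): tracing per-module RSS
    # memory growth during a forward pass; requires `psutil` to be installed.
    #
    #     model.add_memory_hooks()
    #     outputs = model(input_ids)             # forward pass records per-module deltas
    #     for name, module in model.named_modules():
    #         print(name, module.mem_rss_diff)   # bytes gained during the forward pass
    #     model.reset_memory_hooks_state()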

    @property
    def device(self):
        return next(self.parameters()).device


class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    r""" Base class for all models.

        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to do a forward pass in the network.

        Returns:
            torch.Tensor with dummy inputs
        """
        return {"input_ids": torch.tensor(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        Args:
            value (:obj:`nn.Module`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
        the weights instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings
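    # Illustrative sketch (not part of the original source): for a model that has output
    # embeddings (e.g. a LM head), `tie_weights()` makes the output projection share the input
    # embedding Parameter, unless `config.torchscript` is set, in which case the weight is cloned.
    #
    #     model.tie_weights()
    #     tied = model.get_output_embeddings().weight is model.get_input_embeddings().weight
    #     # `tied` is True when torchscript is disabled, False when the weights were cloned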

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embedding`` Module of the model.

        Return: ``torch.nn.Embedding``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds
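    # Illustrative usage sketch (not part of the original source): growing the vocabulary after
    # adding tokens to a tokenizer; the tokenizer/model classes and checkpoint are assumptions.
    #
    #     from transformers import BertModel, BertTokenizer
    #     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    #     model = BertModel.from_pretrained("bert-base-uncased")
    #     tokenizer.add_tokens(["[NEW1]", "[NEW2]"])
    #     embeddings = model.resize_token_embeddings(len(tokenizer))
    #     assert embeddings.num_embeddings == len(tokenizer)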

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embedding``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def init_weights(self):
        """ Initialize and prune weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)
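    # Illustrative usage sketch (not part of the original source), mirroring the docstring
    # example: prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
    #
    #     model.prune_heads({1: [0, 2], 2: [2, 3]})
    #     print(model.config.pruned_heads)   # {1: [0, 2], 2: [2, 3]}, re-applied on reload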

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model itself if we are using distributed training
        model_to_save = self.module if hasattr(self, "module") else self

        # Attach architecture to the config
        model_to_save.config.architectures = [model_to_save.__class__.__name__]

        # Save configuration file
        model_to_save.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))
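    # Illustrative round-trip sketch (not part of the original source); the directory name is
    # an assumption and must exist before saving.
    #
    #     import os
    #     save_directory = "./my_model_directory/"
    #     os.makedirs(save_directory, exist_ok=True)
    #     model.save_pretrained(save_directory)   # writes the configuration file and the weights file
    #     reloaded = model.__class__.from_pretrained(save_directory)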

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
              - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
              - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
              - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
              - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
              - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                    - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                    - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a ``config`` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
                else:
                    msg = (
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url to model weight files named one of {} but "
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
                        )
                    )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if "gamma" in key:
                    new_key = key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = key.replace("beta", "bias")
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module: nn.Module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ""
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix) for s in state_dict.keys()
            ):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)

            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
                ]

                missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
        model.tie_weights()  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate Dropout modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        return model
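    # Illustrative sketch (not part of the original source): inspecting which weights were
    # missing or unused while loading; `BertModel` and the checkpoint name are assumed examples.
    #
    #     model, loading_info = BertModel.from_pretrained("bert-base-uncased", output_loading_info=True)
    #     print(loading_info["missing_keys"], loading_info["unexpected_keys"], loading_info["error_msgs"])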

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def prepare_scores_for_generation(self, scores, **kwargs):
        return scores
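    # Illustrative sketch (not part of the original source): subclasses typically override these
    # hooks. A hypothetical model with a decoding cache could pass `past` along like this:
    #
    #     class MyModelWithLMHead(PreTrainedModel):
    #         def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    #             if past is not None:
    #                 input_ids = input_ids[:, -1].unsqueeze(-1)  # only feed the last new token
    #             return {"input_ids": input_ids, "past": past}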

    def _do_output_past(self, outputs):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
        has_output_past = getattr(self.config, "output_past", False)
        mem_len = getattr(self.config, "mem_len", 0)
        if len(outputs) <= 1:
            return False
        if mem_len > 0 or has_output_past:
            return True
        return False

    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty
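    # Worked example (illustrative, not part of the original source): with repetition_penalty=1.2,
    # a previously generated token with log-prob -1.0 becomes -1.0 * 1.2 = -1.2 (less likely),
    # while a positive score 2.0 would become 2.0 / 1.2.
    #
    #     lprobs = torch.tensor([[2.0, -1.0, 0.5]])
    #     prev_output_tokens = torch.tensor([[1]])   # token 1 was already generated
    #     model.enforce_repetition_penalty_(
    #         lprobs, batch_size=1, num_beams=1,
    #         prev_output_tokens=prev_output_tokens, repetition_penalty=1.2,
    #     )
    #     # lprobs[0, 1] is now -1.2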

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape `(1,)`.

            max_length: (`optional`) int
                The max length of the sequence to be generated.  Between `min_length` and infinity. Default to 20.

            min_length: (`optional`) int
                The min length of the sequence to be generated.  Between 0 and infinity. Default to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.

            top_p: (`optional`) float
                The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

            pad_token_id: (`optional`) int
                Padding token. Defaults to the model-specific pad_token_id, or None if it does not exist.

            bos_token_id: (`optional`) int
                BOS token. Defaults to bos_token_id as defined in the model's config.

            eos_token_id: (`optional`) int
                End of sequence token used to stop the generation. Defaults to eos_token_id as defined in the model's config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Default to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Default to 1.

            attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

            `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) from initial context 'The dog'
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, do_sample=True, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 independent sequences by sampling
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # avoid duplicating outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model (see PR 3140)
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        vocab_size = self.config.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )

            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if self.config.is_encoder_decoder:
            # create empty decoder_input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
936
                decoder_start_token_id,
Patrick von Platen's avatar
Patrick von Platen committed
937
938
939
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
940
            cur_len = 1
941
942
943
944
945
946
947
948
949
950
951
952
            batch_idx = self.encoder_outputs_batch_dim_idx
            assert (
                batch_size == encoder_outputs[0].shape[batch_idx]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
            expanded_idx = (
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )
            encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
        else:
            encoder_outputs = None
            cur_len = input_ids.shape[-1]

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )

        return output
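
    # Usage sketch (illustrative only; model names and parameter values below are placeholders,
    # not prescribed by this module): `generate()` is typically called on a model with an LM head.
    #
    #     from transformers import GPT2LMHeadModel, GPT2Tokenizer
    #
    #     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    #     model = GPT2LMHeadModel.from_pretrained("gpt2")
    #     input_ids = tokenizer.encode("The weather today is", return_tensors="pt")
    #     # sampling with temperature / top-k / nucleus filtering
    #     sampled = model.generate(input_ids, max_length=30, do_sample=True, top_k=50, top_p=0.95)
    #     # beam search keeping several hypotheses
    #     beams = model.generate(input_ids, max_length=30, num_beams=5, num_return_sequences=3, early_stopping=True)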

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)

            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                next_token_logits[:, eos_token_id] = -float("inf")

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for the new generated token if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            cur_len = cur_len + 1

        # if there are different sentence lengths in the batch, some sentences have to be padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
        else:
            decoded = input_ids

        for hypo_idx, hypo in enumerate(input_ids):
            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]
        return decoded
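
    # Worked example (toy values) of the finished-sentence bookkeeping in the loop above:
    # with batch_size=2, suppose unfinished_sents = [1, 0] (the second sentence already produced EOS),
    # pad_token_id = 0 and the freshly chosen next_token = [7, 5]. Then
    #     tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
    #                   = [7 * 1 + 0 * 0, 5 * 0 + 0 * 1] = [7, 0]
    # so the finished sentence keeps receiving pad tokens while the other one keeps generating.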
    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens across beams
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(
                    next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
                )

            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                scores[:, eos_token_id] = -float("inf")

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                num_batch_hypotheses = batch_size * num_beams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_batch_tokens = calc_banned_tokens(
                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
                )
                for i, banned_tokens in enumerate(banned_batch_tokens):
                    scores[i, banned_tokens] = -float("inf")

            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size)
            )

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypothesis across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token_id have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and word IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id

                    # add to generated hypotheses if end of sentence
                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted word if it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we're done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len=cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)

            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for the new generated token if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

            # update current length
            cur_len = cur_len + 1

        # finalize all open beam hypotheses and add them to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are filled with pad_token
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded

    # force one of token_ids to be generated by setting prob of all other tokens to 0.
    def _force_token_ids_generation(self, scores, token_ids):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        all_but_token_ids_mask = torch.tensor(
            [x for x in range(self.config.vocab_size) if x not in token_ids],
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        scores[:, all_but_token_ids_mask] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` and `mems` is at 2nd position
            reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
            reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
            # check that shape matches
            assert reordered_layer_past.shape == layer_past.shape
            reordered_past.append(reordered_layer_past)
        past = tuple(reordered_past)
        return past
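
    # Illustration (toy shapes, not a real model cache): the batch*beam axis of each cached
    # tensor sits at dimension 1, so reordering with `beam_idx` is equivalent to an index_select
    # on that dimension.
    #
    #     import torch
    #     layer_past = torch.arange(2 * 4 * 3).view(2, 4, 3).float()  # (2, batch*beams=4, hidden=3)
    #     beam_idx = torch.tensor([1, 1, 3, 0])  # beams 0 and 1 both continue old beam 1
    #     reordered = layer_past[:, beam_idx]    # same values as _reorder_cache((layer_past,), beam_idx)[0]
    #     assert reordered.shape == layer_past.shape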


def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
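
# Worked example (toy token ids): for the hypothesis [5, 7, 5, 7] with no_repeat_ngram_size=2,
# the last generated token is 7 and the bigram (7, 5) already occurred, so 5 is banned next.
#
#     import torch
#     calc_banned_tokens(torch.tensor([[5, 7, 5, 7]]), num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
#     # -> [[5]]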


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
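
# Usage sketch (toy logits): with top_k=2 everything outside the two highest-scoring tokens is
# set to -inf, so a following softmax/multinomial step can never pick it.
#
#     import torch
#     top_k_top_p_filtering(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), top_k=2)
#     # -> tensor([[-inf, 3.0, 2.0, -inf]])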


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
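
# Usage sketch (toy values): BeamHypotheses keeps the `num_beams` best finished hypotheses,
# scored by length-penalised log-probability, and `is_done` reports whether a new candidate
# could still improve on the worst stored hypothesis.
#
#     import torch
#     hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
#     hyps.add(torch.tensor([0, 5, 7]), sum_logprobs=-1.2)   # score -1.2 / 3 = -0.4
#     hyps.add(torch.tensor([0, 5, 9]), sum_logprobs=-0.9)   # score -0.9 / 3 = -0.3
#     hyps.is_done(best_sum_logprobs=-30.0, cur_len=3)       # True: -10.0 cannot beat the worst score -0.4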


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
        """
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
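
# Note (illustrative): Conv1D(nf, nx) behaves like nn.Linear(nx, nf) with a transposed weight
# layout, i.e. y = x @ weight + bias where weight has shape (nx, nf).
#
#     import torch
#     conv = Conv1D(nf=8, nx=4)
#     x = torch.randn(2, 3, 4)                      # (batch, seq_len, nx)
#     y = conv(x)                                   # (batch, seq_len, nf)
#     assert torch.allclose(y, x @ conv.weight + conv.bias, atol=1e-6)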


class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample
        """
        hsz = hidden_states.shape[-1]
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Log probabilities for the ``is_impossible`` label of the answers.
    """

    def __init__(self, config):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(
        self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
    ):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs
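
    # Usage sketch (dummy config and tensors; the namespace only carries the attributes this head reads):
    #
    #     import torch, types
    #     cfg = types.SimpleNamespace(hidden_size=16, layer_norm_eps=1e-12, start_n_top=2, end_n_top=2)
    #     head = SQuADHead(cfg)
    #     hidden = torch.randn(3, 7, 16)                          # (batch, seq_len, hidden)
    #     loss, = head(hidden, start_positions=torch.tensor([1, 2, 3]), end_positions=torch.tensor([2, 4, 5]))
    #     preds = head(hidden)                                    # inference: top start/end candidates + cls_logits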


class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' or another activation string => add that activation to the output, anything else => no activation. Default: None
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation = (
            get_activation(activation_string) if activation_string else Identity()
        )  # type: typing.Callable

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
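
# Usage sketch for SequenceSummary (illustrative only). The ad-hoc config below simply
# mirrors the attributes read in __init__ above; its concrete values are assumptions made
# for the example, not defaults taken from any real model configuration.
def _sequence_summary_example():
    from types import SimpleNamespace

    config = SimpleNamespace(
        summary_type="last",          # summarize with the hidden state of the last token
        summary_use_proj=True,        # add a linear projection on top of the summary
        summary_proj_to_labels=True,  # project to num_labels rather than hidden_size
        num_labels=3,
        hidden_size=8,
        summary_activation=None,      # no activation after the projection
        summary_first_dropout=0.0,
        summary_last_dropout=0.0,
    )
    summary_module = SequenceSummary(config)
    hidden_states = torch.rand(2, 5, config.hidden_size)  # (bsz, seq_len, hidden_size)
    return summary_module(hidden_states)  # expected shape: (bsz, num_labels) == (2, 3)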


def create_position_ids_from_input_ids(input_ids, padding_idx):
    """ Replace non-padding symbols with their position numbers. Position numbers begin at
    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
    `utils.make_positions`.

    :param torch.Tensor x:
    :return torch.Tensor:
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indicies = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indicies.long() + padding_idx
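
# Usage sketch (illustrative only; padding_idx=1 is an assumption chosen for the example,
# in the style of RoBERTa tokenizers, not a value imposed by the function itself):
def _create_position_ids_example():
    input_ids = torch.tensor([[5, 7, 9, 1, 1], [4, 6, 1, 1, 1]])  # 1 marks padding here
    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
    # Non-padding tokens are numbered padding_idx+1, padding_idx+2, ...; padding keeps padding_idx:
    # tensor([[2, 3, 4, 1, 1],
    #         [2, 3, 1, 1, 1]])
    return position_ids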


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
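
# Usage sketch (illustrative only): keep three of the six output units of a toy linear layer.
def _prune_linear_layer_example():
    layer = nn.Linear(10, 6)
    index = torch.tensor([0, 2, 4])                   # output units to keep
    pruned = prune_linear_layer(layer, index, dim=0)  # dim=0 prunes along the output dimension
    assert pruned.weight.shape == (3, 10) and pruned.bias.shape == (3,)
    return pruned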


def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer
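
# Usage sketch (illustrative only), assuming the GPT-2-style Conv1D defined earlier in this
# file, i.e. Conv1D(nf, nx) with a weight of shape (nx, nf):
def _prune_conv1d_layer_example():
    layer = Conv1D(6, 10)                             # 10 input features -> 6 output features
    index = torch.tensor([1, 3, 5])                   # output features to keep
    pruned = prune_conv1d_layer(layer, index, dim=1)  # dim=1 prunes Conv1D output features
    assert pruned.weight.shape == (10, 3) and pruned.bias.shape == (3,)
    return pruned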


def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
    elif isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))
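

# Usage sketch (illustrative only): prune_layer picks the conventional pruning dimension
# for each layer type when dim is not given (0 for nn.Linear, 1 for Conv1D).
def _prune_layer_example():
    index = torch.tensor([0, 2, 4])
    pruned_linear = prune_layer(nn.Linear(10, 6), index)  # dispatches to prune_linear_layer, dim=0
    pruned_conv1d = prune_layer(Conv1D(6, 10), index)     # dispatches to prune_conv1d_layer, dim=1
    assert pruned_linear.weight.shape == (3, 10)
    assert pruned_conv1d.weight.shape == (10, 3)
    return pruned_linear, pruned_conv1d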