"docs/vscode:/vscode.git/clone" did not exist on "bbe9c6981b96bf38f74a2460e3c98baef0b43e17"
modeling_utils.py 106 KB
Newer Older
1
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .file_utils import (
    DUMMY_INPUTS,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_NAME,
    cached_path,
    hf_bucket_url,
    is_remote_url,
)


logger = logging.getLogger(__name__)

try:
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, input):
            return input


class ModuleUtilsMixin:
    """
    A few utilities for torch.nn.Modules, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the module.
        """
        params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
        return sum(p.numel() for p in params)
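    # Illustrative usage sketch (assumes any concrete subclass, e.g. BertModel, loaded elsewhere):
    #
    #     from transformers import BertModel
    #     model = BertModel.from_pretrained("bert-base-uncased")
    #     total = model.num_parameters()                          # all parameters
    #     trainable = model.num_parameters(only_trainable=True)   # trainable parameters only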

    @staticmethod
    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_pre_forward = mem.rss
        return None

    @staticmethod
    def _hook_rss_memory_post_forward(module, *args, **kwargs):
        try:
            import psutil
        except ImportError:
            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")

        process = psutil.Process(os.getpid())
        mem = process.memory_info()
        module.mem_rss_post_forward = mem.rss
        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
        return None

    def add_memory_hooks(self):
        """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
            Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
        """
        for module in self.modules():
            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
            module.register_forward_hook(self._hook_rss_memory_post_forward)
        self.reset_memory_hooks_state()

    def reset_memory_hooks_state(self):
        for module in self.modules():
            module.mem_rss_diff = 0
            module.mem_rss_post_forward = 0
            module.mem_rss_pre_forward = 0
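    # Illustrative usage sketch (requires `psutil`; `input_ids` is a hypothetical batch of token ids):
    #
    #     model.add_memory_hooks()
    #     outputs = model(input_ids)        # each sub-module records its RSS increase in `mem_rss_diff`
    #     print(model.mem_rss_diff)
    #     model.reset_memory_hooks_state()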

    @property
    def device(self) -> device:
        """
        Get torch.device from module, assuming that the whole module has one device.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5

            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].device

    @property
    def dtype(self) -> dtype:
        """
        Get torch.dtype from module, assuming that the whole module has one dtype.
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5

            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """type: torch.Tensor -> torch.Tensor"""
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
        # /transformer/transformer_layers.py#L270
        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
        # encoder_extended_attention_mask.transpose(-1, -2))
        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility

        if self.dtype == torch.float16:
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
        elif self.dtype == torch.float32:
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        else:
            raise ValueError(
                "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
                    self.dtype
                )
            )

        return encoder_extended_attention_mask

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
        """Makes broadcastable attention mask and causal mask so that future and masked tokens are ignored.

        Arguments:
            attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
            input_shape: tuple, shape of input_ids
            device: torch.Device, usually self.device

        Returns:
            torch.Tensor with dtype of attention_mask.dtype
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
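    # Illustrative shape sketch (hypothetical 2D padding mask for an encoder-style model):
    #
    #     attention_mask = torch.tensor([[1, 1, 1, 0]])          # [batch_size, seq_length]
    #     extended_mask = model.get_extended_attention_mask(
    #         attention_mask, attention_mask.shape, attention_mask.device
    #     )                                                       # [batch_size, 1, 1, seq_length], values 0.0 / -10000.0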

    def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
        """Prepare the head mask if needed. A value of 1.0 in `head_mask` indicates a head to keep
        (attention_probs has shape bsz x n_heads x N x N).

        Arguments:
            head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
            num_hidden_layers: int
        Returns:
            Tensor of shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
            or list with [None] for each layer
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if needed + fp16 compatibility
        return head_mask
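    # Illustrative shape sketch (hypothetical model with 12 layers and 12 heads):
    #
    #     head_mask = torch.ones(12)                              # keep every head, same mask for all layers
    #     head_mask_5d = model.get_head_mask(head_mask, num_hidden_layers=12)
    #     # head_mask_5d has shape [12, 1, 12, 1, 1] and broadcasts over batch and sequence dimensions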


class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    r""" Base class for all models.

        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to do a forward pass in the network.

        Returns:
            torch.Tensor with dummy inputs
        """
        return {"input_ids": torch.tensor(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value: nn.Module):
        """
        Set model's input embeddings

        Args:
            value (:obj:`nn.Module`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`nn.Module`:
                A torch module mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.
        If the `torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing, so we
        clone the weights instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending on whether we are using TorchScript or not
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds
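    # Illustrative usage sketch (assumes a tokenizer to which new tokens were added):
    #
    #     tokenizer.add_tokens(["<new_token>"])
    #     model.resize_token_embeddings(len(tokenizer))   # grow the embedding matrix to the new vocab size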

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(
        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
    ) -> torch.nn.Embedding:
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            old_embeddings: ``torch.nn.Embedding``
                Old embeddings to be resized.
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embedding``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def init_weights(self):
        """ Initialize and prunes weights if needed. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune: Dict):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)
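    # Illustrative usage sketch: prune heads 0 and 2 in layer 1 and heads 2 and 3 in layer 2.
    #
    #     model.prune_heads({1: [0, 2], 2: [2, 3]})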

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.

            Arguments:
                save_directory: directory to which to save.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model itself if we are using distributed training
        model_to_save = self.module if hasattr(self, "module") else self

        # Attach architecture to the config
        model_to_save.config.architectures = [model_to_save.__class__.__name__]

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)

        if getattr(self.config, "xla_device", False):
            import torch_xla.core.xla_model as xm

            if xm.is_master_ordinal():
                # Save configuration file
                model_to_save.config.save_pretrained(save_directory)
            # xm.save takes care of saving only from master
            xm.save(model_to_save.state_dict(), output_model_file)
        else:
            model_to_save.config.save_pretrained(save_directory)
            torch.save(model_to_save.state_dict(), output_model_file)

        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with ``model.train()``.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
              - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
              - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
              - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
              - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
              - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                    - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                    - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_tf` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                            pretrained_model_name_or_path,
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert (
                    from_tf
                ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index"
                )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            try:
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception:
                raise OSError(
                    "Unable to load weights from pytorch checkpoint file. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                )

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith(".index"):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model

                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError:
                    logger.error(
                        "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
                    )
                    raise
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if "gamma" in key:
                    new_key = key.replace("gamma", "weight")
                if "beta" in key:
                    new_key = key.replace("beta", "bias")
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module: nn.Module, prefix=""):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ""
            model_to_load = model
            has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
            if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
                start_prefix = cls.base_model_prefix + "."
            if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)

            if model.__class__.__name__ != model_to_load.__class__.__name__:
                base_model_state_dict = model_to_load.state_dict().keys()
                head_model_state_dict_without_base_prefix = [
                    key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
                ]

                missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

            if len(missing_keys) > 0:
                logger.info(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                logger.info(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )
            if len(error_msgs) > 0:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
        model.tie_weights()  # make sure token embedding weights are still tied if needed

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info

        if hasattr(config, "xla_device") and config.xla_device:
            import torch_xla.core.xla_model as xm

            model = xm.send_cpu_data_to_device(model, xm.xla_device())
            model.to(xm.xla_device())

        return model

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def prepare_logits_for_generation(self, logits, **kwargs):
        return logits

    def _use_cache(self, outputs, use_cache):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
        if len(outputs) <= 1 or use_cache is False:
            return False
        if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
            return False
        return True

    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **model_specific_kwargs
    ) -> torch.LongTensor:
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape `(1,)`.

            max_length: (`optional`) int
                The max length of the sequence to be generated.  Between `min_length` and infinity. Default to 20.

            min_length: (`optional`) int
                The min length of the sequence to be generated.  Between 0 and infinity. Default to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.

            top_p: (`optional`) float
                The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

            pad_token_id: (`optional`) int
                Padding token. Defaults to the model-specific pad_token_id or None if it does not exist.

            bos_token_id: (`optional`) int
                BOS token. Defaults to `bos_token_id` as defined in the models config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to `eos_token_id` as defined in the models config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Default to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Default to 1.

            attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id=None: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

            use_cache: (`optional`) bool
                If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.

            model_specific_kwargs: (`optional`) dict
                Additional model specific kwargs will be forwarded to the `forward` function of the model.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
Patrick von Platen's avatar
Patrick von Platen committed
958
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
959
        use_cache = use_cache if use_cache is not None else self.config.use_cache
960
961
962
963
964
965
966
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # avoid duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model (see PR 3140)
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        if hasattr(self.config, "vocab_size"):
            vocab_size = self.config.vocab_size
        elif (
            self.config.is_encoder_decoder
            and hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "vocab_size")
        ):
            vocab_size = self.config.decoder.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )

            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if self.config.is_encoder_decoder:
            # create empty decoder_input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
            cur_len = 1

            assert (
                batch_size == encoder_outputs[0].shape[0]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = (
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )
            # expand encoder_outputs
            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

        else:
            encoder_outputs = None
            cur_len = input_ids.shape[-1]

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )

            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                next_token_logits[:, eos_token_id] = -float("inf")

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for new generated input for decoder-only models
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        # if there are different sentence lengths in the batch, some sequences have to be padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
        else:
            decoded = input_ids

        for hypo_idx, hypo in enumerate(input_ids):
            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

        return decoded

    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens across beams
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]

            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(
                    next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
                )

            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                next_token_logits = self.prepare_logits_for_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                scores[:, eos_token_id] = -float("inf")

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                num_batch_hypotheses = batch_size * num_beams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_batch_tokens = calc_banned_ngram_tokens(
                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
                )
                for i, banned_tokens in enumerate(banned_batch_tokens):
                    scores[i, banned_tokens] = -float("inf")

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                for i, banned_tokens in enumerate(banned_tokens):
                    scores[i, banned_tokens] = -float("inf")

            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size)
            )

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypothesis across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence or last iteration
                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted token if it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we're done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len=cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and update current length
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            cur_len = cur_len + 1

            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for new generated input for decoder-only models
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        # finalize all open beam hypotheses and add them to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are filled with pad_token
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert (len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded

    @staticmethod
    def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
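
# Illustrative sketch (not part of the library API): how calc_banned_ngram_tokens behaves.
# With no_repeat_ngram_size=2 and a hypothesis that already ends with token 5 while the
# bigram (5, 3) has occurred before, token 3 is banned for the next step:
#
#     prev = torch.tensor([[1, 5, 3, 5]])
#     calc_banned_ngram_tokens(prev, num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
#     # -> [[3]]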


def calc_banned_bad_words_ids(prev_input_ids: Tensor, bad_words_ids: Iterable[Iterable[int]]) -> List[List[int]]:
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev tokens they can't be equal
            return False

        if prev_tokens[-len(tokens) :] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
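
# Illustrative sketch (not part of the library API): a token is banned for the next step
# when the preceding tokens of its bad-word sequence already match the end of the generated
# ids; single-token bad words are always banned.
#
#     prev = torch.tensor([[10, 20, 30]])
#     calc_banned_bad_words_ids(prev, bad_words_ids=[[99], [30, 40]])
#     # -> [[99, 40]]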


def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 probability are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
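
# Illustrative sketch (not part of the library API): top-k/top-p filtering is typically
# applied to the next-token logits before sampling. `model` and `input_ids` below are
# assumed to be a causal LM and an encoded prompt.
#
#     next_token_logits = model(input_ids)[0][:, -1, :]                      # (batch_size, vocab_size)
#     filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.95)
#     probs = F.softmax(filtered_logits, dim=-1)
#     next_token = torch.multinomial(probs, num_samples=1)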


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len=None):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            if cur_len is None:
                cur_len = self.max_length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
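
# Illustrative sketch (not part of the library API): beam search keeps one BeamHypotheses
# container per batch element; a sentence is finished once is_done() reports that no
# pending beam can still beat the worst kept hypothesis.
#
#     hyps = BeamHypotheses(num_beams=2, max_length=20, length_penalty=1.0, early_stopping=False)
#     hyps.add(torch.tensor([0, 5, 7]), sum_logprobs=-1.2)
#     hyps.add(torch.tensor([0, 5, 9]), sum_logprobs=-0.8)
#     hyps.is_done(best_sum_logprobs=-5.0, cur_len=3)  # True: -5.0 / 3 is worse than the worst kept score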


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed
        """
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
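
# Illustrative sketch (not part of the library API): Conv1D(nf, nx) maps inputs of shape
# (..., nx) to (..., nf), i.e. it behaves like nn.Linear(nx, nf) with transposed weights.
#
#     conv = Conv1D(nf=8, nx=4)
#     out = conv(torch.randn(2, 10, 4))   # out.shape == (2, 10, 8)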


class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x
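
# Illustrative sketch (not part of the library API): PoolerStartLogits collapses the hidden
# dimension to a single start logit per token; masked positions get a large negative value.
# `config`, `hidden_states` and `p_mask` below are assumed inputs.
#
#     pooler = PoolerStartLogits(config)                       # config.hidden_size == H
#     start_logits = pooler(hidden_states, p_mask=p_mask)      # (batch_size, seq_len)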


class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.
        """
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """

    def __init__(self, config):
        super().__init__()
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample
        """
        hsz = hidden_states.shape[-1]
        assert (
            start_states is not None or start_positions is not None
        ), "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid positions such as query and special symbols (PAD, SEP, CLS).
            1.0 means the token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Logits for the ``is_impossible`` label of the answers.
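    Example (a minimal usage sketch; ``config``, ``hidden_states`` and the label tensors are assumed to come from the surrounding model and data pipeline)::

        squad_head = SQuADHead(config)

        # training: pass the labels and read the loss from the returned tuple
        outputs = squad_head(hidden_states, start_positions=start_positions, end_positions=end_positions)
        total_loss = outputs[0]

        # inference: omit the labels to get the beam-search candidates
        start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = squad_head(hidden_states)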
    """

    def __init__(self, config):
        super().__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(
        self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
    ):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
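            # each of the start_n_top start candidates now gets its own end-position distribution over the full sequence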
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs


class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence's hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => not implemented for now (would use multi-head attention)
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: name of an activation to add after the projection (e.g. 'tanh'); None => no activation. Default: None.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
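        Example (a minimal sketch; ``config`` is assumed to define the ``summary_*`` attributes above and ``hidden_states`` to come from a model's last layer)::

            summary = SequenceSummary(config)
            summary_vector = summary(hidden_states)                        # shape (bsz, hidden_size) or (bsz, num_labels)
            # with summary_type == 'cls_index', the token position can be passed explicitly
            summary_vector = summary(hidden_states, cls_index=cls_index)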
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


def create_position_ids_from_input_ids(input_ids, padding_idx):
    """ Replace non-padding symbols with their position numbers. Position numbers begin at
    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
    `utils.make_positions`.

    :param torch.Tensor input_ids: token ids, with padding positions equal to ``padding_idx``
    :param int padding_idx: index of the padding token
    :return torch.Tensor: position ids with the same shape as ``input_ids``
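    Example (a minimal worked example, with ``padding_idx == 1`` as in RoBERTa-style vocabularies)::

        input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])
        create_position_ids_from_input_ids(input_ids, padding_idx=1)
        # -> tensor([[2, 3, 4, 5, 1, 1]]): positions start at padding_idx + 1, padding keeps padding_idx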
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
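        Example (a minimal sketch; with ``dim == 0`` the entries of ``index`` are the output features to keep)::

            layer = nn.Linear(768, 12)
            index = torch.tensor([0, 2, 5])
            pruned = prune_linear_layer(layer, index, dim=0)   # nn.Linear(in_features=768, out_features=3)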
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer


def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
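        Example (a minimal sketch; with ``dim == 1`` the entries of ``index`` are the output features to keep)::

            layer = Conv1D(12, 768)                            # nf=12 outputs, nx=768 inputs
            index = torch.tensor([0, 2, 5])
            pruned = prune_conv1d_layer(layer, index, dim=1)   # Conv1D with 3 outputs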
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    new_layer.bias.requires_grad = False
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer


def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.
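        Example (a minimal sketch; the default ``dim`` depends on the layer class)::

            pruned = prune_layer(nn.Linear(768, 12), torch.tensor([0, 2, 5]))   # dispatches to prune_linear_layer with dim=0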
    """
    if isinstance(layer, nn.Linear):
        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
    elif isinstance(layer, Conv1D):
        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))


def apply_chunking_to_forward(
    chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
) -> torch.Tensor:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
    It then applies a layer `forward_fn` to each chunk independently to save memory.
    If the `forward_fn` is independent across the `chunk_dim`, this function yields the
    same result as applying `forward_fn` directly to `input_tensors`.

    Args:
        chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `input_tensors[0].shape[chunk_dim] / chunk_size`
        chunk_dim: int - the dimension over which the input_tensors should be chunked
        forward_fn: fn - the forward fn of the model
        input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
    Returns:
        a Tensor with the same shape as the one `forward_fn` would have produced if applied directly


    Examples::

        # rename the usual forward() fn to forward_chunk()
        def forward_chunk(self, hidden_states):
            hidden_states = self.decoder(hidden_states)
            return hidden_states

        # implement a chunked forward function
        def forward(self, hidden_states):
            return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
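        # a quick sanity check (made-up tensors, not from the library): chunking a function that is
        # independent across chunk_dim gives the same result as calling it directly
        hidden_states = torch.rand(1, 64, 256)
        output = apply_chunking_to_forward(16, 1, lambda x: x * 2, hidden_states)   # equals hidden_states * 2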
    """

    assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
    tensor_shape = input_tensors[0].shape
    assert all(
        input_tensor.shape == tensor_shape for input_tensor in input_tensors
    ), "All input tenors have to be of the same shape"

    # inspect.signature has existed since python 3.5 and is a python method -> no problem with backward compatibility
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    assert num_args_in_forward_chunk_fn == len(
        input_tensors
    ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
        num_args_in_forward_chunk_fn, len(input_tensors)
    )

    if chunk_size > 0:
        assert (
            input_tensors[0].shape[chunk_dim] % chunk_size == 0
        ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
            input_tensors[0].shape[chunk_dim], chunk_size
        )

        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # chunk input tensor into tuples
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
        # apply forward fn to every tuple
        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        # concatenate output at same dimension
        return torch.cat(output_chunks, dim=chunk_dim)

    return forward_fn(*input_tensors)