modeling_utils.py 53.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

18
19
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
20

21
22
import copy
import json
23
24
import logging
import os
thomwolf's avatar
thomwolf committed
25
from io import open
R茅mi Louf's avatar
R茅mi Louf committed
26
import warnings
27

28
import six
29
30
import torch
from torch import nn
31
32
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
R茅mi Louf's avatar
R茅mi Louf committed
33
from tqdm import trange
34

35
from .configuration_utils import PretrainedConfig
thomwolf's avatar
thomwolf committed
36
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
37
38
39
40

logger = logging.getLogger(__name__)


# Newer PyTorch releases ship `torch.nn.Identity`; older ones do not, so fall
# back to a minimal pass-through module exposing the same call signature.
Identity = getattr(nn, "Identity", None)
if Identity is None:

    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.

        Any constructor arguments are accepted and ignored; ``forward`` hands
        its input back unchanged.
        """

        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

54
class PreTrainedModel(nn.Module):
    r""" Base class for all models.

        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        """ Store ``config`` on the model.

        Raises:
            ValueError: if ``config`` is not a :class:`~transformers.PretrainedConfig` instance.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        """ The sub-module named by ``base_model_prefix``, or ``self`` when there is no such attribute. """
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """ Get model's input embeddings.

            Delegates to the base model; raises ``NotImplementedError`` when this model is its own base model
            (such classes must override this method).
        """
        base_model = self.base_model
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """ Set model's input embeddings.

            Delegates to the base model; raises ``NotImplementedError`` when this model is its own base model
            (such classes must override this method).
        """
        base_model = self.base_model
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """ Get model's output embeddings.
            Return None if the model doesn't have output embeddings.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending on whether we are using TorchScript or not.

            Afterwards, an existing output bias is zero-padded up to the (possibly larger) number of
            rows of the tied weight, and ``out_features`` is synced with ``num_embeddings`` when both
            attributes exist.
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        # After a vocabulary resize the bias may be shorter than the weight: pad it with zeros.
        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                'constant',
                0
            )
        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = self.base_model  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again (tie_weights is defined on every PreTrainedModel)
        self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace this model's input embeddings by a resized copy and return the new embeddings. """
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embeddings``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
            or equal to the current number of tokens. The new module lives on the same device and uses
            the same dtype as ``old_embeddings.weight``.
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings on the same device *and* with the same dtype as the old ones,
        # otherwise e.g. a half-precision model would end up with a float32 embedding matrix.
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def init_weights(self):
        """ Initialize and prune weights if needed, then tie input/output embeddings. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.

            ``save_directory`` must already exist (checked with an ``assert``).
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model itself if we are using distributed training
        model_to_save = self.module if hasattr(self, 'module') else self

        # Save configuration file
        model_to_save.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` is None but no ``state_dict`` is given (or ``from_tf`` is set),
                since there is then no source for the weights.
            EnvironmentError: if the weights file cannot be found or downloaded.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        # `pretrained_model_name_or_path` may legitimately be None (see docstring), so guard the substring test.
        if (pretrained_model_name_or_path is not None
                and "albert" in pretrained_model_name_or_path
                and "v2" in pretrained_model_name_or_path):
            logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
                           "https://github.com/google-research/google-research/issues/119 for more information.")

        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        resume_download = kwargs.pop('resume_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *model_args,
                cache_dir=cache_dir, return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                **kwargs
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                assert from_tf, "Error finding file {}, no file or TF 1.X checkpoint found".format(pretrained_model_name_or_path)
                archive_file = pretrained_model_name_or_path + ".index"

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
                                                    proxies=proxies, resume_download=resume_download)
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                            archive_file)
                else:
                    msg = "Model name '{}' was not found in model name list ({}). " \
                        "We assumed '{}' was a path or url to model weight files named one of {} but " \
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ', '.join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Without a weights file the only usable source is an explicit PyTorch `state_dict`;
        # fail clearly here instead of inside torch.load(None) / None.endswith(...) below.
        if resolved_archive_file is None and (from_tf or state_dict is None):
            raise ValueError(
                "`pretrained_model_name_or_path` can only be None when both `config` and `state_dict` "
                "are provided (and `from_tf` is False).")

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if key == 'lm_head.decoder.weight':
                    new_key = 'lm_head.weight'
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                model.__class__.__name__, "\n\t".join(error_msgs)))

        # `model` may have been replaced by the TF loaders above, hence the hasattr check.
        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model

    def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 **model_kwargs):
        """ Generic sequence generator for single-stack models with a LM head.

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling. Options left to ``None`` fall back to
        the `Sampler` defaults.

        Params:
            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1, 0)
            **length**: int
                The number of tokens to generate. Required.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
            **num_beams**: (`optional`) int
                Currently ignored (beam search is not implemented).
            **temperature**: (`optional`) float
                The value used to modulate the next token probabilities.
            **top_k**: (`optional`) int
                The parameter used for top-k filtering.
            **top_p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
            **model_kwargs**:
                Extra keyword arguments forwarded to the model at every decoding step
                (through ``_prepare_inputs_for_decoding``).

        Return: `torch.LongTensor`
            The prompt followed by the ``length`` generated ids, with the leading
            batch dimension of size 1 squeezed out.

        Raises:
            AttributeError: if the model has no LM head (no output embeddings).
            ValueError: if ``length`` is not provided.
        """

        if input_ids is None:
            input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")

        if length is None:
            raise ValueError("`length`, the number of tokens to generate, must be provided.")

        # Only forward the options the caller actually set, so the Sampler's own
        # defaults apply to the rest (passing None would e.g. break `repetition_penalty > 1`).
        optional_sampler_args = {
            "k": top_k,
            "p": top_p,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        sampler_config = {name: value for name, value in optional_sampler_args.items() if value is not None}

        sampler = Sampler(do_sample=do_sample, **sampler_config)
        generated_sequence = input_ids
        for _ in trange(length):
            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
            outputs = self(**arguments)
            next_tokens_logits = outputs[0][:, -1, :]
            next_tokens = sampler.get_one_token(
                next_tokens_logits, generated_sequence
            )
            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

        return generated_sequence.squeeze(0)

    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        """ Build the keyword arguments for one forward pass during generation.

            Returns a new dict holding ``model_kwargs`` plus ``input_ids`` (``dict.update``
            returns None, so the merged dict must be built and returned explicitly).
            Models needing extra inputs can override this hook.
        """
        return dict(model_kwargs, input_ids=input_ids)


class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in choosing the most probable token at each
    step, is the default behaviour. Sampling with varying temperature, top_k
    and nucleus filtering is also implemented.

    Attributes:
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering. A value of 0 disables it.
        **p**: float between 0 and 1
            Parameter for the nucleus filtering. A value of 0 disables it.
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low temperatures
            put more emphasis on highly probable tokens while high temperatures tend
            to smooth the probability distribution.
        **repetition_penalty**: strictly positive float
            The penalty applied to repeating ids. Only values > 1 are applied.
    """

    def __init__(
        self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
    ):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty

        # Only penalties strictly greater than 1 are applied.
        self.do_apply_repetition_penalty = repetition_penalty > 1

        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
                However p is a probability and its value must lie between 0 and 1. In effect, no filtering
                will be applied. If this is not the behavior you expect, change the value of p.""".format(
                    self.p
                )
            )

    def get_one_token(self, next_token_logits, past_sequence):
        """ Pick the next token id for each row of ``next_token_logits``.

        ``next_token_logits`` (batch_size, vocab_size) may be modified in place.
        Returns a ``torch.LongTensor`` of shape ``(batch_size, 1)``: a sample from
        the filtered distribution when ``do_sample`` is set, the argmax otherwise.
        """
        logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if self.do_sample:
            logits = self.apply_temperature(logits)
            logits = self.apply_top_k_filter(logits)
            logits = self.apply_nucleus_filter(logits)
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        return torch.argmax(logits, dim=-1).unsqueeze(-1)

    def apply_repetition_penalty(self, logits, past_sequence):
        """ Penalize tokens that already appear in the generated sequence.

        Positive logits are divided by the penalty and negative logits are
        multiplied by it, so a penalty > 1 always lowers the score of a token
        that was already generated. Each batch row is penalized using its own
        history. ``logits`` is modified in place and returned.

        Args:
            logits: tensor of shape ``(batch_size, vocab_size)``.
            past_sequence: ``torch.LongTensor`` of shape ``(batch_size, seq_len)``.

        .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
           language model for controllable generation." arXiv preprint
           arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            for batch_idx in range(past_sequence.size(0)):
                for token_idx in set(past_sequence[batch_idx].tolist()):
                    # Dividing a negative logit by a penalty > 1 would *raise*
                    # its probability, so negative logits are multiplied instead.
                    if logits[batch_idx, token_idx] < 0:
                        logits[batch_idx, token_idx] *= self.repetition_penalty
                    else:
                        logits[batch_idx, token_idx] /= self.repetition_penalty
        return logits

    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The higher the
        temperature, the flatter (closer to uniform) the distribution becomes;
        the lower it is, the more the mass concentrates on the most probable tokens.

        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
        MIT press, 2016.
        """
        # when dividing a float by 0, torch returns inf which in turns breaks the
        # multinomial with an error message that is not very helpful. It is better
        # for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
                If you wanted to do greedy sampling, set instead `do_sample` to False.
                Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature

    def apply_top_k_filter(self, logits):
        """ Keep only the k tokens with the highest logits (the size-k set with the
        largest total probability) and set every other logit to -inf, in place.

        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
        story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k > 0:
            vocabulary_size = logits.size(-1)
            if self.k > vocabulary_size:
                warnings.warn(
                    """You provided a value for k ({}) that is larger than the vocabulary size ({}).
                    We adjusted k's value to the vocabulary size; if that was what you intended to do
                    we recommend setting k to 0 instead. It this is not the behavior you expected,
                    choose a value of k that is smaller than the vocabulary size.""".format(
                        self.k, vocabulary_size
                    )
                )
                # Persist the clamp so the warning is emitted only once per sampler.
                self.k = vocabulary_size

            indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
            logits[indices_to_remove] = -float("Inf")

        return logits

    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].
        Filtered logits are set to -inf in place.

        .. Holtzman, Ari, et al. "The curious case of neural text
           degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p > 0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probabilities = F.softmax(sorted_logits, dim=-1)
            cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

            # Remove tokens with cumulative probability above the threshold,
            # but keep the first token above the threshold.
            sorted_indices_to_remove = cumulative_probabilities > self.p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False

            # scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = -float("Inf")

        return logits

689

thomwolf's avatar
thomwolf committed
690
691
class Conv1D(nn.Module):
    """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a Linear layer but the weights are transposed: ``weight``
    has shape ``(nx, nf)`` and the layer computes ``x @ weight + bias``.

    Args:
        nf: number of output features.
        nx: number of input features.
    """
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """ Project the last dimension of ``x`` from ``nx`` to ``nf`` features.

        Leading dimensions are flattened for the matrix product and restored
        afterwards, so an input of shape ``(..., nx)`` yields ``(..., nf)``.
        """
        size_out = x.size()[:-1] + (self.nf,)
        # `reshape` (unlike `view`) also accepts non-contiguous inputs such as
        # transposed tensors; for contiguous inputs it is still a free view.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        return x.view(*size_out)


thomwolf's avatar
thomwolf committed
709
710
class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """
    def __init__(self, config):
        super(PoolerStartLogits, self).__init__()
        # Scores every position with a single linear unit (hidden_size -> 1).
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            **hidden_states**: ``torch.FloatTensor`` whose last dimension is ``hidden_size``.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                invalid position mask such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.

        Returns:
            The per-position scores, i.e. ``hidden_states`` projected to one value
            with the trailing singleton dimension squeezed out.
        """
        scores = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return scores

        # Masked positions get a huge negative score. 1e30 overflows fp16, so the
        # largest finite half-precision magnitude is used for fp16 weights.
        is_half = next(self.parameters()).dtype == torch.float16
        big_negative = 65500 if is_half else 1e30
        return scores * (1 - p_mask) - big_negative * p_mask


class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """
    def __init__(self, config):
        super(PoolerEndLogits, self).__init__()
        # Input is the concatenation [hidden_states, start_states] along the last dim.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **hidden_states**: ``torch.FloatTensor`` whose last dimension is ``hidden_size``.
                When ``start_positions`` is used it must be ``(batch_size, seq_len, hidden_size)``.
            **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.

        Returns:
            End logits: ``hidden_states``' shape without its last dimension.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            # 1e30 overflows fp16; use the largest finite half-precision magnitude there.
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
    def __init__(self, config):
        super(PoolerAnswerClass, self).__init__()
        # Input is the concatenation [start_states, cls_token_state] along the last dim.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``.
            **start_states**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
                hidden state of the first token for the labeled span (or any other
                per-example start representation).
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample

        Returns:
            ``torch.FloatTensor`` of shape ``(batch_size,)``: one answerability logit per example.
        """
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
            ``start_n_top`` and ``end_n_top`` set the beam sizes used at inference.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: numeric tensor of shape ``(batch_size,)`` (1 / 0)
            Whether the question has a possible answer in the paragraph or not.
            It is cast to the logits' floating-point dtype for the loss.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Softmax probabilities (not log probabilities, despite the name) for the top
            config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Softmax probabilities (not log probabilities, despite the name) for the top
            ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Logits (pre-sigmoid) for the ``is_impossible`` label of the answers.
    """
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting.
            # Squeeze out-of-place so the caller's label tensors are not mutated.
            start_positions, end_positions, cls_index, is_impossible = [
                x.squeeze(-1) if x is not None and x.dim() > 1 else x
                for x in (start_positions, end_positions, cls_index, is_impossible)
            ]

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                # BCEWithLogitsLoss requires a floating-point target; integer labels would error.
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            # NOTE: these are softmax probabilities despite the `log_probs` names;
            # they also weight the expected start state computed below.
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Probability-weighted average of the hidden states over start positions.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs


class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default: no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """
    def __init__(self, config):
        super(SequenceSummary, self).__init__()

        self.summary_type = getattr(config, 'summary_type', 'last')
        if self.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if getattr(config, 'summary_use_proj', False):
            if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = Identity()
        if getattr(config, 'summary_activation', None) == 'tanh':
            self.activation = nn.Tanh()

        self.first_dropout = Identity()
        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token

            Returns: the summary with the seq_len axis removed, passed through
                first_dropout -> summary -> activation -> last_dropout.

            Raises: ValueError if summary_type is not one of the supported values.
        """
        # The sequence axis is addressed from the end (-2) so optional leading
        # dimensions between bsz and seq_len are preserved, as for 'cls_index'.
        if self.summary_type == 'last':
            output = hidden_states[..., -1, :]
        elif self.summary_type == 'first':
            output = hidden_states[..., 0, :]
        elif self.summary_type == 'mean':
            output = hidden_states.mean(dim=-2)
        elif self.summary_type == 'cls_index':
            if cls_index is None:
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == 'attn':
            raise NotImplementedError
        else:
            # Previously this fell through to an UnboundLocalError on `output`.
            raise ValueError("Unknown summary_type: {}".format(self.summary_type))

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
def prune_linear_layer(layer, index, dim=0):
    """ Build a new ``nn.Linear`` that keeps only the entries of ``layer`` selected by ``index``.

    Used to remove heads. ``layer`` itself is left untouched.

    Args:
        layer: the ``nn.Linear`` to prune.
        index: ``torch.LongTensor`` of the entries to keep along ``dim``.
        dim: 0 keeps the selected rows of ``weight`` (and the matching bias
            entries); 1 keeps the selected columns (the bias is copied whole).

    Returns:
        A fresh ``nn.Linear`` on ``layer``'s device whose parameters have
        ``requires_grad=True``.
    """
    device = layer.weight.device
    index = index.to(device)
    has_bias = layer.bias is not None

    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    if has_bias:
        source_bias = layer.bias if dim == 1 else layer.bias[index]
        kept_bias = source_bias.clone().detach()

    shape = list(layer.weight.size())
    shape[dim] = len(index)
    out_features, in_features = shape
    pruned = nn.Linear(in_features, out_features, bias=has_bias).to(device)

    # Copy into the new leaf parameters without autograd tracking; they keep
    # requires_grad=True afterwards.
    with torch.no_grad():
        pruned.weight.copy_(kept_weight.contiguous())
        if has_bias:
            pruned.bias.copy_(kept_bias.contiguous())
    return pruned


def prune_conv1d_layer(layer, index, dim=1):
    """ Build a new ``Conv1D`` that keeps only the entries of ``layer`` selected by ``index``.

    A Conv1D works like a Linear layer, but its weight is stored transposed as
    ``(nx, nf)``. Used to remove heads; ``layer`` itself is left untouched.

    Args:
        layer: the ``Conv1D`` to prune.
        index: ``torch.LongTensor`` of the entries to keep along ``dim``.
        dim: 1 keeps the selected output features (and matching bias entries);
            0 keeps the selected input features (the bias is copied whole).

    Returns:
        A fresh ``Conv1D`` on ``layer``'s device whose parameters have
        ``requires_grad=True``.
    """
    device = layer.weight.device
    index = index.to(device)

    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    source_bias = layer.bias if dim == 0 else layer.bias[index]
    kept_bias = source_bias.clone().detach()

    shape = list(layer.weight.size())
    shape[dim] = len(index)
    in_features, out_features = shape
    pruned = Conv1D(out_features, in_features).to(device)

    # Copy into the new leaf parameters without autograd tracking; they keep
    # requires_grad=True afterwards.
    with torch.no_grad():
        pruned.weight.copy_(kept_weight.contiguous())
        pruned.bias.copy_(kept_bias.contiguous())
    return pruned
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065


def prune_layer(layer, index, dim=None):
    """ Prune a Conv1D or nn.Linear layer to keep only the entries in ``index``.

    Used to remove heads.

    Args:
        layer: an ``nn.Linear`` or ``Conv1D`` instance.
        index: ``torch.LongTensor`` of the entries to keep.
        dim: axis to prune; when None, defaults to 0 for ``nn.Linear`` and 1 for
            ``Conv1D`` (the output-feature axis of each weight layout).

    Returns:
        A new layer of the same kind with ``requires_grad=True`` parameters.

    Raises:
        ValueError: if ``layer`` is neither an ``nn.Linear`` nor a ``Conv1D``.
    """
    if isinstance(layer, nn.Linear):
        pruner, default_dim = prune_linear_layer, 0
    elif isinstance(layer, Conv1D):
        pruner, default_dim = prune_conv1d_layer, 1
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))
    return pruner(layer, index, dim=default_dim if dim is None else dim)