# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import copy
import json
import logging
import os
from io import open

import six
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME

logger = logging.getLogger(__name__)


try:
    # Newer PyTorch releases ship their own identity module; prefer it when importable.
    from torch.nn import Identity
except ImportError:
    # Older PyTorch compatibility: a minimal stand-in exposing the same behaviour.
    class Identity(nn.Module):
        r"""A placeholder identity operator that is argument-insensitive.

        Any positional or keyword constructor arguments are accepted and ignored;
        ``forward`` hands back its input object unchanged.
        """

        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

class PreTrainedModel(nn.Module):
    r""" Base class for all models.

        :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.

        Derived classes are also expected to provide ``_init_weights(module)`` (used by
        :func:`init_weights` and when resizing embeddings) and, if head pruning is used,
        a ``_prune_heads`` method on their base model; this base class defines neither.
    """
    config_class = None
    pretrained_model_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        """ Store ``config`` on the model.

            Raises:
                ValueError: if ``config`` is not an instance of :class:`~transformers.PretrainedConfig`.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        # Save config in model
        self.config = config

    @property
    def base_model(self):
        """ The sub-module named by ``base_model_prefix`` if this model has one, otherwise ``self``. """
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self):
        """ Get model's input embeddings.

            Delegates to the base model; a model that is its own base model must override this.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """ Set model's input embeddings.

            Delegates to the base model; a model that is its own base model must override this.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """ Get model's output embeddings.
            Return None if the model doesn't have output embeddings.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending on whether we are using TorchScript or not.
        """
        if self.config.torchscript:
            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        # Zero-pad the output bias so its length matches the number of rows of the
        # (possibly resized) shared weight matrix.
        if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None:
            output_embeddings.bias.data = torch.nn.functional.pad(
                output_embeddings.bias.data,
                (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                'constant',
                0
            )
        if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embedding`` Module of the model.

        Return: ``torch.nn.Embedding``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        # Tie weights again if needed
        if hasattr(self, 'tie_weights'):
            self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace the input embeddings by a resized copy and return the (new) input embeddings. """
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_input_embeddings(new_embeddings)
        return self.get_input_embeddings()

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Module from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``torch.nn.Embedding``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
            (or equal to the current number of tokens). The new module lives on the same device and
            has the same dtype as ``old_embeddings.weight``.
        """
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings on the same device *and* with the same dtype as the old ones;
        # otherwise e.g. a half-precision model would silently get a float32 embedding matrix.
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        new_embeddings.to(device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        return new_embeddings

    def init_weights(self):
        """ Initialize and prune weights if needed, then tie input/output embeddings. """
        # Initialize weights
        self.apply(self._init_weights)

        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
                E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON

        self.base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # Only save the model itself if we are using distributed training
        model_to_save = self.module if hasattr(self, 'module') else self

        # Save configuration file
        model_to_save.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with ``model.train()``

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            state_dict: (`optional`) dict:
                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
                In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: if ``pretrained_model_name_or_path`` is None but ``config`` or ``state_dict`` is
                missing, or ``from_tf`` is True (there would be no source for the configuration/weights).
            EnvironmentError: if the weights file cannot be found locally or downloaded.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)

        """
        if pretrained_model_name_or_path is not None and (
                "albert" in pretrained_model_name_or_path and "v2" in pretrained_model_name_or_path):
            logger.warning("There is currently an upstream reproducibility issue with ALBERT v2 models. Please see " +
                           "https://github.com/google-research/google-research/issues/119 for more information.")

        config = kwargs.pop('config', None)
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        force_download = kwargs.pop('force_download', False)
        resume_download = kwargs.pop('resume_download', False)
        proxies = kwargs.pop('proxies', None)
        output_loading_info = kwargs.pop('output_loading_info', False)

        # Without a path there is nothing to load the configuration or weights from, so both must be
        # supplied explicitly. Fail early with a clear message instead of deep inside config loading,
        # `torch.load(None)` or `None.endswith(...)`.
        if pretrained_model_name_or_path is None and (config is None or state_dict is None or from_tf):
            raise ValueError(
                "When `pretrained_model_name_or_path` is None, both `config` and `state_dict` must be "
                "provided (and `from_tf` must be False).")

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *model_args,
                cache_dir=cache_dir, return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                **kwargs
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError("Error no file named {} found in directory {} or `from_tf` set to False".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                        pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                assert from_tf, "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
                    pretrained_model_name_or_path + ".index")
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                # Anything else is treated as a URL (or missing path) and resolved by `cached_path`.
                archive_file = pretrained_model_name_or_path

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download,
                                                    proxies=proxies, resume_download=resume_download)
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
                            archive_file)
                else:
                    msg = "Model name '{}' was not found in model name list ({}). " \
                        "We assumed '{}' was a path or url to model weight files named one of {} but " \
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ', '.join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if state_dict is None and not from_tf:
            # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
            # only load weight files from trusted sources.
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        if from_tf:
            if resolved_archive_file.endswith('.index'):
                # Load from a TensorFlow 1.X checkpoint - provided by original authors
                model = cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
            else:
                # Load from our TensorFlow 2.0 checkpoints
                try:
                    from transformers import load_tf2_checkpoint_in_pytorch_model
                    model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
                except ImportError as e:
                    logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
                        "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
                    raise e
        else:
            # Convert old format to new format if needed from a PyTorch state_dict
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if key == 'lm_head.decoder.weight':
                    new_key = 'lm_head.weight'
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            # Renaming happens after the scan so the dict is not mutated while being iterated.
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata

            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
            # so we need to apply the function recursively.
            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            # Make sure we are able to load base models as well as derived models (with heads)
            start_prefix = ''
            model_to_load = model
            if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
                # Loading a checkpoint of a model with heads into a bare base model: strip the prefix.
                start_prefix = cls.base_model_prefix + '.'
            if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
                # Loading a bare base-model checkpoint into a model with heads: load into the base sub-module.
                model_to_load = getattr(model, cls.base_model_prefix)

            load(model_to_load, prefix=start_prefix)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
            if len(error_msgs) > 0:
                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                model.__class__.__name__, "\n\t".join(error_msgs)))

        if hasattr(model, 'tie_weights'):
            model.tie_weights()  # make sure word embedding weights are still tied

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model


class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed:
            ``weight`` has shape ``(nx, nf)`` and the output is ``x @ weight + bias``.

            Args:
                nf: number of output features (size of the output's last dimension).
                nx: number of input features (size of the input's last dimension).
        """
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        """ Apply the affine map over the last dimension of ``x``.

            Any leading dimensions are flattened for the matrix multiply and restored
            afterwards, so the result has shape ``x.shape[:-1] + (nf,)``.
        """
        size_out = x.size()[:-1] + (self.nf,)
        # `reshape` (rather than `view`) also accepts non-contiguous inputs such as
        # transposed tensors; for contiguous inputs it is still a zero-copy view.
        x = torch.addmm(self.bias, x.reshape(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x


class PoolerStartLogits(nn.Module):
    """ Linear scoring head producing one SQuAD start logit per token from sequence hidden states. """
    def __init__(self, config):
        super(PoolerStartLogits, self).__init__()
        # Projects each token's hidden state down to a single scalar score.
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Score every position as a candidate answer start.

            Args:
                hidden_states: tensor whose last dimension is ``config.hidden_size``.
                **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
                    invalid position mask such as query and special symbols (PAD, SEP, CLS)
                    1.0 means token should be masked.

            Returns:
                Logits of shape ``hidden_states.shape[:-1]``; masked positions are pushed
                to a very large negative value.
        """
        logits = self.dense(hidden_states).squeeze(-1)
        if p_mask is None:
            return logits

        # fp16 cannot represent 1e30 (its max is about 65504), so half-precision
        # modules use a smaller "minus infinity" for masked positions.
        is_half = next(self.parameters()).dtype == torch.float16
        big_negative = 65500 if is_half else 1e30
        return logits * (1 - p_mask) - big_negative * p_mask


class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """
    def __init__(self, config):
        super(PoolerEndLogits, self).__init__()
        # Scores each position from the concatenation [token hidden state; start hidden state].
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **hidden_states**: ``torch.FloatTensor`` whose last dimension is ``hidden_size``.
                Must be 3-D ``(batch_size, seq_len, hidden_size)`` when ``start_positions`` is used.
            **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``
                hidden states of the first token of the labeled span, already repeated
                for every position (they are concatenated with ``hidden_states`` on the last dim).
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span; used to gather
                ``start_states`` from ``hidden_states``.
            **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                1.0 means token should be masked.

        Returns:
            ``torch.FloatTensor`` of shape ``hidden_states.shape[:-1]``: one end logit per position.
        """
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            # fp16 cannot represent 1e30, so use a value close to its maximum instead.
            if next(self.parameters()).dtype == torch.float16:
                x = x * (1 - p_mask) - 65500 * p_mask
            else:
                x = x * (1 - p_mask) - 1e30 * p_mask

        return x


class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
    def __init__(self, config):
        super(PoolerAnswerClass, self).__init__()
        # Scores the concatenation [start hidden state; CLS hidden state] as one logit.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """
        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
                hidden states of sequence tokens.
            **start_states**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
                hidden state representing the start of the labeled span (one vector per example).
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: torch.LongTensor of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample

        Returns:
            ``torch.FloatTensor`` of shape ``(batch_size,)``: one answerability logit per example.
        """
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
        else:
            cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x


class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: torch.LongTensor of shape ``(batch_size,)``
            position of the CLS token. At inference, if None, take the last token.
            During training, the answerability loss is only computed when both
            ``cls_index`` and ``is_impossible`` are given.
        **is_impossible**: ``torch.Tensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
            Integer or float labels are accepted; they are cast to the logits' dtype
            before the binary cross-entropy loss.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
            1.0 means token should be masked.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) scalar ``torch.FloatTensor``:
            Average of the start and end token classification losses, plus 0.5 times the
            is_impossible loss when ``cls_index`` and ``is_impossible`` are provided.
        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
            Softmax probabilities (despite the name, not log-probabilities) for the top
            config.start_n_top start token possibilities (beam-search).
        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
            Indices for the top config.start_n_top start token possibilities (beam-search).
        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Softmax probabilities (despite the name, not log-probabilities) for the top
            ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
            ``torch.FloatTensor`` of shape ``(batch_size,)``
            Logits (before sigmoid) for the ``is_impossible`` label of the answers.
    """
    def __init__(self, config):
        super(SQuADHead, self).__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting.
            # Rebind rather than squeeze in place so the caller's label tensors are not mutated.
            start_positions, end_positions, cls_index, is_impossible = [
                t.squeeze(-1) if t is not None and t.dim() > 1 else t
                for t in (start_positions, end_positions, cls_index, is_impossible)
            ]

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                # BCEWithLogitsLoss needs a floating-point target matching the logits' dtype;
                # integer labels (e.g. a LongTensor) would otherwise make the loss fail.
                cls_loss = loss_fct_cls(cls_logits, is_impossible.to(dtype=cls_logits.dtype))

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            outputs = (total_loss,) + outputs

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            # NOTE: softmax probabilities, not log-probabilities; the public output
            # names keep the historical "log_probs" wording.
            start_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_probs, start_top_index = torch.topk(start_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_probs, end_top_index = torch.topk(end_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_probs = end_top_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Start-probability-weighted average of hidden states, used as the "start"
            # representation when scoring answerability.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_probs, start_top_index, end_top_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) (total_loss,)
        return outputs


class SequenceSummary(nn.Module):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default: no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation

        Any of these attributes may be missing from the config; a missing attribute means
        'last' summary, no projection, no activation and no dropout respectively.
    """
    def __init__(self, config):
        super(SequenceSummary, self).__init__()

        self.summary_type = getattr(config, 'summary_type', 'last')
        if self.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if getattr(config, 'summary_use_proj', False):
            if getattr(config, 'summary_proj_to_labels', False) and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        self.activation = Identity()
        if getattr(config, 'summary_activation', None) == 'tanh':
            self.activation = nn.Tanh()

        self.first_dropout = Identity()
        if getattr(config, 'summary_first_dropout', 0) > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if getattr(config, 'summary_last_dropout', 0) > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(self, hidden_states, cls_index=None):
        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token

            Returns: a Tensor of shape [bsz, ..., out_size], where out_size is hidden_size,
                or the projection size when summary_use_proj is set.
            Raises: ValueError if summary_type is not one of the supported values.
        """
        # All selections index the sequence axis (-2) so optional leading dims are supported.
        if self.summary_type == 'last':
            output = hidden_states[..., -1, :]
        elif self.summary_type == 'first':
            output = hidden_states[..., 0, :]
        elif self.summary_type == 'mean':
            output = hidden_states.mean(dim=-2)
        elif self.summary_type == 'cls_index':
            if cls_index is None:
                cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
        elif self.summary_type == 'attn':
            raise NotImplementedError
        else:
            raise ValueError("Unknown summary_type: {}".format(self.summary_type))

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

        Args:
            layer: the ``nn.Linear`` to prune (it is left unchanged).
            index: ``torch.LongTensor`` with the indices to keep along ``dim``.
            dim: 0 keeps output features (rows of ``weight`` and matching ``bias`` entries);
                1 keeps input features (columns of ``weight``; ``bias`` is copied whole).

        Returns:
            A new ``nn.Linear`` on the same device and with the same dtype as ``layer``.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    b = None
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    # Match dtype as well as device: a fresh nn.Linear is created in the default dtype,
    # which would silently change the precision of e.g. an fp16 model's layer.
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(
        device=layer.weight.device, dtype=layer.weight.dtype)
    # In-place writes into leaf parameters that require grad must happen under no_grad;
    # the parameters keep requires_grad=True afterwards.
    with torch.no_grad():
        new_layer.weight.copy_(W.contiguous())
        if b is not None:
            new_layer.bias.copy_(b.contiguous())
    return new_layer


def prune_conv1d_layer(layer, index, dim=1):
    """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
        A Conv1D works like a Linear layer but its weights are transposed.
        Return the pruned layer as a new layer with requires_grad=True.
        Used to remove heads.

        Args:
            layer: the ``Conv1D`` to prune (it is left unchanged); its weight has shape ``(nx, nf)``.
            index: ``torch.LongTensor`` with the indices to keep along ``dim``.
            dim: 1 (default) keeps output features (columns of ``weight`` and matching ``bias`` entries);
                0 keeps input features (rows of ``weight``; ``bias`` is copied whole).

        Returns:
            A new ``Conv1D`` on the same device and with the same dtype as ``layer``.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if dim == 0:
        b = layer.bias.clone().detach()
    else:
        b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    # Conv1D(nf, nx) allocates a weight of shape (nx, nf) == new_size. Match dtype as well
    # as device so pruning does not silently change the precision of the layer.
    new_layer = Conv1D(new_size[1], new_size[0]).to(device=layer.weight.device, dtype=layer.weight.dtype)
    # In-place writes into leaf parameters that require grad must happen under no_grad;
    # the parameters keep requires_grad=True afterwards.
    with torch.no_grad():
        new_layer.weight.copy_(W.contiguous())
        new_layer.bias.copy_(b.contiguous())
    return new_layer
853
854
855
856
857
858
859
860
861
862
863
864
865


def prune_layer(layer, index, dim=None):
    """ Dispatch pruning of an ``nn.Linear`` or ``Conv1D`` layer to the matching helper.

        Keeps only the entries listed in ``index`` and returns a fresh layer whose
        parameters have requires_grad=True. Used to remove attention heads.
        When ``dim`` is None, the natural output dimension is used: 0 for
        ``nn.Linear`` and 1 for ``Conv1D`` (whose weight is transposed).

        Raises:
            ValueError: if ``layer`` is neither an ``nn.Linear`` nor a ``Conv1D``.
    """
    if isinstance(layer, nn.Linear):
        pruner, default_dim = prune_linear_layer, 0
    elif isinstance(layer, Conv1D):
        pruner, default_dim = prune_conv1d_layer, 1
    else:
        raise ValueError("Can't prune layer of class {}".format(layer.__class__))

    target_dim = default_dim if dim is None else dim
    return pruner(layer, index, dim=target_dim)