"examples/run_lm_finetuning.py" did not exist on "936e813c848aa5cad842a18498c440a72505c265"
modeling_openai.py 43.6 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT model."""

18
19
from __future__ import absolute_import, division, print_function, unicode_literals

20
import collections
thomwolf's avatar
thomwolf committed
21
22
import copy
import json
thomwolf's avatar
thomwolf committed
23
import logging
24
25
import math
import os
thomwolf's avatar
thomwolf committed
26
27
import sys
from io import open
thomwolf's avatar
thomwolf committed
28
29
30

import torch
import torch.nn as nn
thomwolf's avatar
thomwolf committed
31
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
32
33
from torch.nn.parameter import Parameter

34
35
from .file_utils import cached_path
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
36
from .modeling import BertLayerNorm as LayerNorm
thomwolf's avatar
thomwolf committed
37

thomwolf's avatar
thomwolf committed
38
39
# Module-wide logger used by everything defined in this file.
logger = logging.getLogger(__name__)

# Shortcut model name -> URL of the pretrained PyTorch weights file.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin",
}
# Shortcut model name -> URL of the matching JSON configuration file.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json",
}
42

43

44
45
46
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

    The folder is expected to contain ``parameters_names.json`` (variable names),
    ``params_shapes.json`` (one shape per parameter) and ten shards
    ``params_0.npy`` ... ``params_9.npy``. The shards are concatenated and split,
    in order, into arrays of the listed shapes.

    Args:
        model: module filled in place. It must expose ``tokens_embed`` and
            ``positions_embed`` embeddings. Every remaining variable name, minus
            its ``model/`` prefix and ``:0`` suffix, must resolve to one of its
            parameters, e.g. ``h0/ln_1/g`` -> ``model.h[0].ln_1.weight``.
        openai_checkpoint_folder_path: path of that folder.

    Returns:
        ``model`` itself, with its parameters overwritten.

    Raises:
        AssertionError: if a variable name does not end with ``:0``, or an
            array's shape differs from the matching model parameter. This is
            raised explicitly, not via ``assert``, so the checks still run
            under ``python -O``.
    """
    import re
    import numpy as np

    log = logging.getLogger(__name__)  # same object as the module-level logger
    log.info("Loading weights...")

    def _read_json(file_name):
        # Context manager so the file handle is closed deterministically.
        with open(os.path.join(openai_checkpoint_folder_path, file_name), "r", encoding='utf-8') as f:
            return json.load(f)

    names = _read_json('parameters_names.json')
    shapes = _read_json('params_shapes.json')
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    shards = [np.load(os.path.join(openai_checkpoint_folder_path, 'params_{}.npy'.format(n))) for n in range(10)]
    # The last offset equals the total size, so the trailing (empty) piece is dropped.
    init_params = np.split(np.concatenate(shards, 0), offsets)[:-1]
    init_params = [param.reshape(shape).squeeze() for param, shape in zip(init_params, shapes)]

    # Array 0 holds the position embeddings and array 1 the token embeddings.
    tokens_w, positions_w = init_params[1], init_params[0]
    if (model.tokens_embed.weight.shape != tokens_w.shape
            or model.positions_embed.weight.shape != positions_w.shape):
        raise AssertionError(model.tokens_embed.weight.shape, tokens_w.shape,
                             model.positions_embed.weight.shape, positions_w.shape)
    model.tokens_embed.weight.data = torch.from_numpy(tokens_w)
    model.positions_embed.weight.data = torch.from_numpy(positions_w)

    # Both embedding arrays were consumed above but only one name is skipped.
    # Presumably the checkpoint lists the embeddings under a single name.
    names = names[1:]
    init_params = init_params[2:]

    indexed_scope = re.compile(r'[A-Za-z]+\d+')
    # One-letter TF leaf names -> PyTorch parameter attribute names.
    leaf_to_attr = {'g': 'weight', 'b': 'bias', 'w': 'weight'}
    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        if name[-2:] != ":0":
            raise AssertionError(name)
        name = name[:-2]
        pointer = model
        for m_name in name.split('/'):
            # e.g. "h0" -> ["h", "0", ""]: attribute "h", then index 0.
            scope = re.split(r'(\d+)', m_name) if indexed_scope.fullmatch(m_name) else [m_name]
            pointer = getattr(pointer, leaf_to_attr.get(scope[0], scope[0]))
            if len(scope) >= 2:
                pointer = pointer[int(scope[1])]
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        log.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model

thomwolf's avatar
thomwolf committed
113
114
115
116
117
118
119
120
121

def gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit, applied element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    """Swish activation ``x * sigmoid(x)``, applied element-wise."""
    return x * torch.sigmoid(x)


# Activation name (``config.afn``) -> callable applied directly to a tensor.
# MLP calls the entry on its hidden states, so every value must be a
# tensor -> tensor function. ``nn.ReLU`` is a module *class*: calling it on a
# tensor would construct a module instead of applying ReLU. Use the functional form.
ACT_FNS = {"relu": nn.functional.relu, "swish": swish, "gelu": gelu}

thomwolf's avatar
thomwolf committed
124

125
class OpenAIGPTConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `OpenAIGPTModel`.
    """
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=40478,
        n_special=0,
        n_positions=512,
        n_ctx=512,
        n_embd=768,
        n_layer=12,
        n_head=12,
        afn="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True
    ):
        """Constructs OpenAIGPTConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel`,
                or the path of a JSON configuration file. In the file case, every key of the
                file is copied verbatim onto this instance and all other arguments are ignored.
            n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
            n_positions: Number of positional embeddings.
            n_ctx: Size of the causal mask (usually same as n_positions).
            n_embd: Dimensionality of the embeddings and hidden states.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            afn: Name of the non-linear activation used in the feed-forward layers.
                It must be a key of `ACT_FNS`: "gelu", "relu" or "swish".
            resid_pdrop: The dropout probability applied to the outputs of the
                attention and feed-forward sub-layers.
            embd_pdrop: The dropout ratio for the embeddings.
            attn_pdrop: The dropout ratio for the attention probabilities.
            layer_norm_epsilon: epsilon to use in the layer norm layers
            initializer_range: The stddev of the normal initializer used for
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)

        Raises:
            ValueError: if the first argument is neither an int nor a str.
        """
        # On Python 2, a `unicode` path is accepted as well as `str`.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.load(reader)
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_special = n_special
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.afn = afn
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
            self.predict_special_tokens = predict_special_tokens
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @property
    def total_tokens_embeddings(self):
        """Number of rows in the token embedding matrix: vocabulary plus special tokens."""
        return self.vocab_size + self.n_special
thomwolf's avatar
thomwolf committed
200

thomwolf's avatar
thomwolf committed
201
202

class Attention(nn.Module):
    """Causal multi-head self-attention with Conv1D input/output projections.

    Args:
        nx: feature size of the input and output (``Block`` passes ``config.n_embd``).
        n_ctx: side length of the lower-triangular causal mask registered as the
            ``bias`` buffer. The mask is cropped to the sequence length, so
            sequences must not be longer than ``n_ctx``.
        config: reads ``n_head``, ``attn_pdrop`` and ``resid_pdrop``.
        scale: if True, divide attention scores by sqrt(per-head feature size).
        output_attentions: if True, ``forward`` returns ``(attention_probs, output)``
            instead of ``output``.
        keep_multihead_output: if True, store the per-head context tensor of the
            last forward pass in ``self.multihead_output``. Its gradient is
            retained when it requires one, e.g. for head-importance metrics.
    """

    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask, shaped (1, 1, n_ctx, n_ctx) to broadcast
        # over the batch and head dimensions of the score tensor.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def prune_heads(self, heads):
        """Remove the given heads (indices into the current heads) in place.

        Shrinks ``c_attn``/``c_proj`` and updates ``n_head`` and ``split_size``.
        """
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # c_attn stacks the query, key and value projections; keep the same
        # features in each of the three slices.
        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)

    def _attn(self, q, k, v, head_mask=None):
        """Masked scaled dot-product attention.

        Returns the context tensor, or ``(attention_probs, context)`` when
        ``self.output_attentions`` is set.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
        # XD: self.b may be larger than w, so we need to crop it
        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
        w = w * b + -1e9 * (1 - b)

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        if self.output_attentions:
            return w, torch.matmul(w, v)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """(batch, head, seq, head_features) -> (batch, seq, head * head_features)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """(batch, seq, features) -> (batch, head, seq, head_features).

        With ``k=True`` the last two axes are swapped, giving
        (batch, head, head_features, seq), ready for ``matmul(q, k)``.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        a = self._attn(query, key, value, head_mask)
        if self.output_attentions:
            # _attn returned (probs, context); unpack before treating `a` as a tensor.
            attentions, a = a
        if self.keep_multihead_output:
            self.multihead_output = a
            # retain_grad() raises on tensors that do not require grad
            # (e.g. under torch.no_grad()), so only request it when meaningful.
            if a.requires_grad:
                self.multihead_output.retain_grad()

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        if self.output_attentions:
            return attentions, a
        return a


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: Conv1D -> activation -> Conv1D -> dropout.

    Args:
        n_state: inner (hidden) size; ``Block`` passes ``4 * config.n_embd``.
        config: reads ``n_embd``, ``afn`` (a key of ``ACT_FNS``) and ``resid_pdrop``.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        model_dim = config.n_embd
        # Submodules are created in the original order (c_fc, then c_proj) so
        # parameter order, and any RNG-dependent initialisation, is unchanged.
        self.c_fc = Conv1D(n_state, model_dim)
        self.c_proj = Conv1D(model_dim, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        expanded = self.c_fc(x)
        activated = self.act(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)


class Block(nn.Module):
    """One transformer layer with post-LayerNorm residual connections.

    ``h = ln_2(n + mlp(n))`` where ``n = ln_1(x + attn(x))``. When
    ``output_attentions`` is set, ``forward`` returns ``(attentions, h)``.
    """

    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
        super(Block, self).__init__()
        width = config.n_embd
        self.output_attentions = output_attentions
        # Submodule creation order is kept (attn, ln_1, mlp, ln_2).
        self.attn = Attention(width, n_ctx, config, scale, output_attentions, keep_multihead_output)
        self.ln_1 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)
        self.ln_2 = LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
        attn_out = self.attn(x, head_mask=head_mask)
        attentions = None
        if self.output_attentions:
            attentions, attn_out = attn_out
        normed = self.ln_1(x + attn_out)
        hidden = self.ln_2(normed + self.mlp(normed))
        return (attentions, hidden) if self.output_attentions else hidden


thomwolf's avatar
thomwolf committed
330
class OpenAIGPTLMHead(nn.Module):
    """ Language Model Head for the transformer

    Projects hidden states onto the (tied) token-embedding matrix to produce
    vocabulary logits. When ``predict_special_tokens`` is False, the logits of
    the special tokens (indices >= ``vocab_size``) are sliced off.
    """

    def __init__(self, model_embeddings_weights, config):
        super(OpenAIGPTLMHead, self).__init__()
        self.n_embd = config.n_embd
        self.vocab_size = config.vocab_size
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        # Pass the configured flag through. set_embeddings_weights() overwrites
        # self.predict_special_tokens, and its default (True) would otherwise
        # silently discard config.predict_special_tokens.
        self.set_embeddings_weights(model_embeddings_weights,
                                    predict_special_tokens=config.predict_special_tokens)

    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
        """Tie the decoder weight to ``model_embeddings_weights`` and set whether
        special-token logits are returned by ``forward``."""
        self.predict_special_tokens = predict_special_tokens
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        if not self.predict_special_tokens:
            lm_logits = lm_logits[..., :self.vocab_size]
        return lm_logits


thomwolf's avatar
thomwolf committed
354
class OpenAIGPTMultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer

    Scores every choice from the hidden state at one selected token position
    (given per choice by ``mc_token_ids``) with a single linear unit.
    """

    def __init__(self, config):
        super(OpenAIGPTMultipleChoiceHead, self).__init__()
        self.n_embd = config.n_embd
        # Applied to a (bsz, hidden_size, num_choices) tensor to reproduce the
        # noise_shape parameter of the TF implementation. The intent is
        # presumably to drop whole hidden features across all choices; for 3-D
        # input Dropout2d's behaviour is torch-version dependent — TODO confirm.
        self.dropout = nn.Dropout2d(config.resid_pdrop)
        self.linear = nn.Linear(config.n_embd, 1)
        nn.init.normal_(self.linear.weight, std=0.02)
        # NOTE(review): normal_(bias, 0) means mean=0 with the default std=1.0,
        # not a zero init — confirm this is intended.
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, hidden_states, mc_token_ids):
        # hidden_states: (bsz, num_choices, seq_length, hidden_size)
        # mc_token_ids:  (bsz, num_choices)
        feature_dim = hidden_states.size(-1)
        # (bsz, num_choices) -> (bsz, num_choices, 1, hidden_size) gather index
        gather_index = mc_token_ids[..., None, None].expand(-1, -1, -1, feature_dim)
        selected = hidden_states.gather(2, gather_index).squeeze(2)  # (bsz, num_choices, hidden_size)
        selected = self.dropout(selected.transpose(1, 2)).transpose(1, 2)
        return self.linear(selected).squeeze(-1)  # (bsz, num_choices)


class OpenAIGPTPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(OpenAIGPTPreTrainedModel, self).__init__()
        if not isinstance(config, OpenAIGPTConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights of a single submodule in place.

        ``nn.Linear``/``nn.Embedding`` weights are drawn from a normal distribution
        with std ``config.initializer_range``. LayerNorm is reset to weight 1 and
        bias 0, and Linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
        """
        Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `openai-gpt`
                - a path or url to a pretrained model archive containing:
                    . a configuration file (`CONFIG_NAME`)
                    . a PyTorch dump of a OpenAIGPTModel instance (`WEIGHTS_NAME`)
                - with `from_tf=True`, a local folder containing:
                    . a configuration file (`CONFIG_NAME`)
                    . the NumPy files of the OpenAI TensorFlow trained weights
                      (`parameters_names.json`, `params_shapes.json`, `params_0.npy` ... `params_9.npy`)
            num_special_tokens: number of special-token embeddings to set up after loading
                (defaults to `config.n_special`). Not applied on the `from_tf` path.
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead
                of pre-trained models. It is not modified.
            *inputs, **kwargs: additional input for the specific OpenAI-GPT class

        Returns:
            The model, or None if the weights or the configuration file could not be
            resolved (the reason is logged).
        """
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)

        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        if from_tf:
            # A TensorFlow checkpoint is a folder of NumPy shards read by
            # load_tf_weights_in_openai_gpt(), not a single WEIGHTS_NAME file,
            # so the folder itself is what gets handed to the converter.
            archive_file = resolved_archive_file = pretrained_model_name_or_path
        else:
            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            except EnvironmentError:
                if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                    logger.error(
                        "Couldn't reach server at '{}' to download pretrained weights.".format(
                            archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find file {} "
                        "at this path or url.".format(
                            pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                            archive_file
                        )
                    )
                return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find file {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                        config_file
                    )
                )
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = OpenAIGPTConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint (stored as NumPy array)
            return load_tf_weights_in_openai_gpt(model, resolved_archive_file)

        # Work on a copy so a caller-supplied state_dict is never mutated;
        # carry over the per-module metadata used by _load_from_state_dict.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Older checkpoints use TF-style one-letter suffixes: .g/.w -> .weight, .b -> .bias
        suffix_map = {".g": ".weight", ".b": ".bias", ".w": ".weight"}
        renames = [(key, key[:-2] + suffix_map[key[-2:]]) for key in state_dict.keys() if key[-2:] in suffix_map]
        for old_key, new_key in renames:
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # A bare transformer checkpoint (no 'transformer.' prefix) is loaded into
        # the .transformer submodule of head models.
        start_model = model
        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
            start_model = model.transformer
        load(start_model, prefix="")

        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )

        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
        return model
thomwolf's avatar
thomwolf committed
554
555


thomwolf's avatar
thomwolf committed
556
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
    Special token embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during fine-tuning if you use them.
    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

    The embeddings are ordered as follows in the token embeddings matrix:
        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + config.n_special - 1]                  ______________________

    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
        total_tokens_embeddings = config.vocab_size + config.n_special
    You should use the associated indices to index the embeddings.

    Params:
        `config`: an OpenAIGPTConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
            This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
            Defaults to 0 .. sequence_length - 1 for every sequence.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings). Token types are looked up in the
            same embedding matrix as the input tokens.
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is kept (not masked),
            0.0 => head is fully masked. A [num_heads] mask is applied identically to every layer.

    Outputs:
        `hidden_states`: a list of all the encoded-hidden-states in the model (length of the list: number of layers + 1 for the output of the embeddings)
            as torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
            (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimension of input_ids)
        If `output_attentions` is True, a tuple `(all_attentions, hidden_states)` is returned instead, where
        `all_attentions` is a list with one attention output per layer.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_openai.OpenAIGPTConfig()

    model = modeling_openai.OpenAIGPTModel(config)
    hidden_states = model(input_ids)
    ```
    """

    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(OpenAIGPTModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
                      keep_multihead_output=keep_multihead_output)
        # Each layer is an independent deep copy of the same freshly-built block
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])

        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens):
        """Update the input embeddings with a new embedding matrix if the special-token count changed.

        The first `config.vocab_size` rows (word embeddings) are copied from the previous matrix;
        every special-token row is freshly initialized. The new matrix is placed on the same
        device and given the same dtype as the previous one.

        Args:
            num_special_tokens: new number of special tokens (stored in `config.n_special`).
        """
        if self.config.n_special == num_special_tokens:
            return
        # Update config
        self.config.n_special = num_special_tokens
        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
        old_embed = self.tokens_embed
        new_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
        new_embed.to(old_embed.weight.device)
        self.init_weights(new_embed)
        # Initialize in the default dtype, then match the previous dtype so that a half-precision
        # model keeps a half-precision embedding (previously only the device was matched).
        new_embed.to(dtype=old_embed.weight.dtype)
        self.tokens_embed = new_embed
        # Copy word embeddings from the previous weights
        self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def get_multihead_outputs(self):
        """ Gather all multi-head outputs.
            Return: list (layers) of multihead module outputs with gradients
        """
        return [h.attn.multihead_output for h in self.h]

    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
        if position_ids is None:
            # Default positions 0 .. seq_len - 1, broadcast over all leading dimensions of input_ids
            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask is reshaped to n_layer x 1 x n_heads x 1 x 1 so head_mask[i] broadcasts over layer i's probs
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        # Flatten all leading dimensions into a single batch dimension
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.tokens_embed(input_ids)
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            # Token types share the token embedding matrix
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        # Shape used to restore the caller's leading dimensions on every returned hidden state
        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = []
        all_hidden_states = [hidden_states.view(*output_shape)]
        for i, block in enumerate(self.h):
            outputs = block(hidden_states, head_mask[i])
            if self.output_attentions:
                attentions, hidden_states = outputs
                all_attentions.append(attentions)
            else:
                hidden_states = outputs
            all_hidden_states.append(hidden_states.view(*output_shape))

        if self.output_attentions:
            return all_attentions, all_hidden_states
        return all_hidden_states
thomwolf's avatar
thomwolf committed
704

705

thomwolf's avatar
thomwolf committed
706
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
    Special token embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during fine-tuning if you use them.
    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

    The embeddings are ordered as follows in the token embeddings matrix:
        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + config.n_special - 1]                  ______________________

    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
        total_tokens_embeddings = config.vocab_size + config.n_special
    You should use the associated indices to index the embeddings.

    Params:
        `config`: an OpenAIGPTConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
            This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor with the same shape as input_ids
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., vocab_size]. Labels are shifted inside the model
            (position t is used as the target for the logits at position t - 1), so pass them aligned with input_ids.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is kept (not masked),
            0.0 => head is fully masked.

    Outputs:
        if `lm_labels` is not `None`:
            Outputs the language modeling loss (attentions are not returned in this case).
        else:
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings]
                (or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are the dimension of input_ids)
            If `output_attentions` is True, a tuple `(all_attentions, lm_logits)` is returned instead.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_openai.OpenAIGPTConfig()

    model = modeling_openai.OpenAIGPTLMHeadModel(config)
    lm_logits = model(input_ids)
    ```
    """

    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
                                          keep_multihead_output=keep_multihead_output)
        # The LM head is built from the transformer's token embedding weight (input/output embedding sharing)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """ Update input and output embeddings with a new embedding matrix.
            Makes sure the input and output embeddings are still shared afterwards.

            Args:
                num_special_tokens: new number of special tokens.
                predict_special_tokens: stored in the config and forwarded to the LM head.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        # The transformer replaced its embedding module, so re-link the LM head to the new weight
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
        if self.transformer.output_attentions:
            all_attentions, hidden_states = hidden_states
        # Only the output of the last layer feeds the LM head
        hidden_states = hidden_states[-1]

        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens; labels equal to -1 are ignored
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            return loss
        if self.transformer.output_attentions:
            return all_attentions, lm_logits
        return lm_logits
thomwolf's avatar
thomwolf committed
801

802

thomwolf's avatar
thomwolf committed
803
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
    Special token embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during fine-tuning if you use them.
    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

    The embeddings are ordered as follows in the token embeddings matrix:
        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + config.n_special - 1]                  ______________________

    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
        total_tokens_embeddings = config.vocab_size + config.n_special
    You should use the associated indices to index the embeddings.

    Params:
        `config`: an OpenAIGPTConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
            This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
            indices selected in the range [0, total_tokens_embeddings[
        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
            which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [-1, 0, ..., total_tokens_embeddings - 1]. All labels set to -1 are ignored (masked),
            the loss is only computed for the labels set in [0, ..., total_tokens_embeddings - 1].
            Labels are shifted inside the model, so pass them aligned with input_ids.
        `mc_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is kept (not masked),
            0.0 => head is fully masked.

    Outputs:
        if `lm_labels` or `mc_labels` is not `None`:
            Outputs a list of losses: the language modeling loss (if `lm_labels` is given) followed by
            the multiple choice loss (if `mc_labels` is given). Attentions are not returned in this case.
        else: a tuple with
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
            If `output_attentions` is True, the list of per-layer attentions is prepended:
            `(all_attentions, lm_logits, multiple_choice_logits)`.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choices, seq length)
    mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choices)

    config = modeling_openai.OpenAIGPTConfig()

    model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
    lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
    ```
    """

    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
                                          keep_multihead_output=keep_multihead_output)
        # The LM head is built from the transformer's token embedding weight (input/output embedding sharing)
        self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
        self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """ Update input and output embeddings with a new embedding matrix.
            Makes sure the input and output embeddings are still shared afterwards.

            Args:
                num_special_tokens: new number of special tokens.
                predict_special_tokens: stored in the config and forwarded to the LM head.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        # The transformer replaced its embedding module, so re-link the LM head to the new weight
        self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, head_mask=None):
        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
        if self.transformer.output_attentions:
            all_attentions, hidden_states = hidden_states
        # Only the output of the last layer feeds both heads
        hidden_states = hidden_states[-1]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
        losses = []
        if lm_labels is not None:
            # Shift so that tokens < n predict n; labels equal to -1 are ignored
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
        if losses:
            return losses
        if self.transformer.output_attentions:
            return all_attentions, lm_logits, mc_logits
        return lm_logits, mc_logits