modeling_tf_gpt2.py 29.4 KB
Newer Older
thomwolf's avatar
WIP  
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

thomwolf's avatar
thomwolf committed
31
32
from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                TFSequenceSummary, shape_list)
thomwolf's avatar
WIP  
thomwolf committed
33
34
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
35
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
thomwolf's avatar
WIP  
thomwolf committed
36
37
38

logger = logging.getLogger(__name__)

# Shortcut model name -> download URL of the matching pretrained TF 2.0 weights (``.h5`` file).
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5",
}


44
45
def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Copy the weights of a PyTorch GPT-2 checkpoint into a TF 2.0 GPT-2 model.

    The model is first called once on a small dummy batch of token ids so that
    its variables are built, then the checkpoint is loaded into it.

    Args:
        tf_model: the TF 2.0 GPT-2 model to populate.
        pytorch_checkpoint_path: path of the PyTorch checkpoint to read.

    Returns:
        The value returned by ``load_pytorch_checkpoint_in_tf2_model``.
    """
    # Dummy (3, 5) batch of token ids, used only to build the network.
    tf_inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    tf_model(tf_inputs, training=False)  # build the network; the outputs are not needed
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
thomwolf's avatar
WIP  
thomwolf committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66


def gelu(x):
    """Apply the Gaussian Error Linear Unit (tanh approximation) element-wise.

    GELU is a smooth alternative to ReLU (https://arxiv.org/abs/1606.08415);
    this computes ``x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: float Tensor to activate.
    Returns:
        ``x`` with the GELU activation applied.
    """
    tanh_arg = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    gate = 0.5 * (1.0 + tf.tanh(tanh_arg))
    return x * gate


class TFAttention(tf.keras.layers.Layer):
    """Multi-head causal self-attention used inside each GPT-2 block.

    Args:
        nx: feature size of the layer input (and of the attention state).
        n_ctx: context size; stored on the layer but not used in this class.
        config: model configuration; ``n_head``, ``attn_pdrop``, ``resid_pdrop`` and
            ``output_attentions`` are read from it.
        scale: if True, attention scores are divided by ``sqrt(head_features)``.
    """

    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super(TFAttention, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        # c_attn produces queries, keys and values in one projection (3 * n_state features).
        self.c_attn = TFConv1D(n_state * 3, nx, name='c_attn')
        self.c_proj = TFConv1D(n_state, nx, name='c_proj')
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        # NOTE(review): head pruning is not implemented for the TF layer; intentionally a no-op.
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(self, inputs, training=False):
        """Causally-masked dot-product attention.

        Args:
            inputs: list ``[q, k, v, attention_mask, head_mask]``; ``attention_mask`` is an
                additive mask (or None) and ``head_mask`` a multiplicative one (or None).
            training: enables attention dropout.
        Returns:
            list ``[context]``, plus the attention probabilities when ``output_attentions``.
        """
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            # Cast to the scores' dtype (not a fixed float32) so non-float32 models can scale too.
            dk = tf.cast(tf.shape(k)[-1], w.dtype)
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Keep allowed scores; replace future positions with -1e4 so softmax ignores them.
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the additive attention mask, matched to the scores' dtype.
            w = w + tf.cast(attention_mask, w.dtype)

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """(batch, head, seq_length, head_features) -> (batch, seq_length, head * head_features)."""
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        """(batch, seq_length, features) -> (batch, head, seq_length, features // n_head)."""
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(self, inputs, training=False):
        """Self-attention over ``x`` with an optional key/value cache.

        Args:
            inputs: list ``[x, layer_past, attention_mask, head_mask]``; ``layer_past`` is None
                or a previously returned ``present`` tensor.
            training: enables dropout.
        Returns:
            list ``[a, present]`` (plus attention probabilities when ``output_attentions``), where
            ``present`` stacks the full keys and values on axis 1:
            ``(batch, 2, head, seq_length, head_features)``.
        """
        x, layer_past, attention_mask, head_mask = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            # Prepend the cached keys/values along the sequence axis.
            past_key, past_value = tf.unstack(layer_past, axis=1)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)
        present = tf.stack([key, value], axis=1)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


thomwolf's avatar
thomwolf committed
165
class TFMLP(tf.keras.layers.Layer):
    """Feed-forward sub-layer of a GPT-2 block: ``c_fc`` -> GELU -> ``c_proj`` -> dropout.

    Args:
        n_state: inner (hidden) size of the feed-forward network.
        config: model configuration; ``n_embd`` and ``resid_pdrop`` are read from it.
    """

    def __init__(self, n_state, config, **kwargs):
        super(TFMLP, self).__init__(**kwargs)
        embed_dim = config.n_embd
        self.c_fc = TFConv1D(n_state, embed_dim, name='c_fc')
        self.c_proj = TFConv1D(embed_dim, n_state, name='c_proj')
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        """Apply the feed-forward network to ``x``; dropout is active only when ``training``."""
        projected = self.c_proj(self.act(self.c_fc(x)))
        return self.dropout(projected, training=training)
thomwolf's avatar
WIP  
thomwolf committed
179
180


thomwolf's avatar
thomwolf committed
181
182
183
class TFBlock(tf.keras.layers.Layer):
    """One GPT-2 transformer block.

    LayerNorm is applied before each sub-layer (attention, then MLP) and each
    sub-layer's output is added back to its input (residual connection).

    Args:
        n_ctx: context size, forwarded to the attention layer.
        config: model configuration; ``n_embd`` and ``layer_norm_epsilon`` are read here.
        scale: forwarded to ``TFAttention`` (scale attention scores).
    """

    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super(TFBlock, self).__init__(**kwargs)
        hidden = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=eps, name='ln_1')
        self.attn = TFAttention(hidden, n_ctx, config, scale, name='attn')
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=eps, name='ln_2')
        self.mlp = TFMLP(4 * hidden, config, name='mlp')

    def call(self, inputs, training=False):
        """Run the block.

        Args:
            inputs: list ``[x, layer_past, attention_mask, head_mask]``.
            training: enables dropout in the sub-layers.
        Returns:
            list ``[x, present]`` plus attentions when the attention layer returns them.
        """
        x, layer_past, attention_mask, head_mask = inputs

        attn_results = self.attn([self.ln_1(x), layer_past, attention_mask, head_mask], training=training)
        x = x + attn_results[0]  # attn_results: a, present, (attentions)

        x = x + self.mlp(self.ln_2(x), training=training)

        return [x] + attn_results[1:]  # x, present, (attentions)

thomwolf's avatar
thomwolf committed
205
206
207
208
209
210

class TFGPT2MainLayer(tf.keras.layers.Layer):
    """GPT-2 transformer body.

    The layers are, in order: token embeddings (``wte``), position embeddings (``wpe``),
    ``config.n_layer`` ``TFBlock`` layers and a final LayerNorm (``ln_f``).
    """

    def __init__(self, config, *inputs, **kwargs):
        # `config` is consumed here and deliberately NOT forwarded to the Keras base
        # constructor: passed positionally it would bind to an unrelated parameter (`trainable`).
        super(TFGPT2MainLayer, self).__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFSharedEmbeddings(config.vocab_size, config.hidden_size, name='wte')
        self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx,
                          config,
                          scale=True,
                          name='h_._{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    @staticmethod
    def _unpack_inputs(inputs):
        """Return ``(input_ids, past, attention_mask, token_type_ids, position_ids, head_mask)``.

        ``inputs`` may be a single tensor (``input_ids`` only), a list/tuple of up to six
        tensors in the order above, or a dict keyed by those names. Missing entries are None.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else None
            attention_mask = inputs[2] if len(inputs) > 2 else None
            token_type_ids = inputs[3] if len(inputs) > 3 else None
            position_ids = inputs[4] if len(inputs) > 4 else None
            head_mask = inputs[5] if len(inputs) > 5 else None
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', None)
            attention_mask = inputs.get('attention_mask', None)
            token_type_ids = inputs.get('token_type_ids', None)
            position_ids = inputs.get('position_ids', None)
            head_mask = inputs.get('head_mask', None)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs
            past = attention_mask = token_type_ids = position_ids = head_mask = None
        return input_ids, past, attention_mask, token_type_ids, position_ids, head_mask

    def call(self, inputs, training=False):
        """Run the transformer.

        Args:
            inputs: see ``_unpack_inputs`` for the accepted formats.
            training: enables dropout.
        Returns:
            tuple ``(last_hidden_state, presents)``, extended with all hidden states when
            ``output_hidden_states`` and with all attentions when ``output_attentions``.

        Raises:
            NotImplementedError: if ``head_mask`` is given (not supported yet).
        """
        input_ids, past, attention_mask, token_type_ids, position_ids, head_mask = self._unpack_inputs(inputs)

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Assumes `past` is the `presents` output of a previous call: dimension -2 of
            # past[0][0] is the cached sequence length.
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # Reshape the 2D mask to [batch_size, 1, 1, to_seq_length] so it broadcasts to
            # [batch_size, num_heads, from_seq_length, to_seq_length] (the causal triangle is
            # applied separately inside the attention layer).
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # 1.0 (attend) -> 0.0 and 0.0 (masked) -> -10000.0; added to the raw scores
            # before the softmax, this effectively removes masked positions.
            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Head masking is not supported by the TF model yet.
        if head_mask is not None:
            raise NotImplementedError
        head_mask = [None] * self.num_hidden_layers

        input_shape = shape_list(input_ids)
        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        inputs_embeds = self.wte(input_ids, mode='embedding')
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            # Token types are embedded with the same matrix as the tokens themselves.
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode='embedding')
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for block, layer_past, layer_head_mask in zip(self.h, past, head_mask):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, layer_past, attention_mask, layer_head_mask], training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)

        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)

thomwolf's avatar
thomwolf committed
340

thomwolf's avatar
thomwolf committed
341
class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """Abstract base for the TF 2.0 GPT-2 models.

    Only sets class attributes: the configuration class, the map of pretrained weight URLs,
    the PyTorch-checkpoint loader and the name prefix of the base transformer. These are
    presumably consumed by ``TFPreTrainedModel`` for downloading and loading pretrained models.
    """
    config_class = GPT2Config
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_gpt2_pt_weights_in_tf2
    base_model_prefix = "transformer"


# Shared introduction prepended (via ``add_start_docstrings``) to the TF GPT-2 model docstrings.
GPT2_START_DOCSTRING = r"""    OpenAI GPT-2 model was proposed in
    `Language Models are Unsupervised Multitask Learners`_
    by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
    corpus of ~40 GB of text data.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. _`Language Models are Unsupervised Multitask Learners`:
        https://openai.com/blog/better-language-models/

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Important note on the model inputs:
        The inputs of the TF 2.0 models are slightly different from the PyTorch ones since
        TF 2.0 Keras doesn't accept named arguments with default values for input Tensors.
        More precisely, input Tensors are gathered in the first argument of the model call function: `model(inputs)`.
        There are three possibilities to gather and feed the inputs to the model:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring
          (input_ids, past, attention_mask, token_type_ids, position_ids, head_mask):
            `model([input_ids, past])` or `model([input_ids, past, attention_mask])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared description of the model inputs, appended (via ``add_start_docstrings``) to the TF GPT-2 model docstrings.
# The inputs are listed in the positional order expected when they are passed as a list/tuple.
GPT2_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
            Not supported yet by the TF 2.0 model: passing it raises ``NotImplementedError``.
"""

@add_start_docstrings("The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
                      GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2Model(TFGPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**:
            tuple of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, 2, num_heads, sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import GPT2Tokenizer, TFGPT2Model

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2Model.from_pretrained('gpt2')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')

    def call(self, inputs, training=False):
        """Forward ``inputs`` (any format accepted by the main layer) through the bare transformer."""
        outputs = self.transformer(inputs, training=training)
        return outputs
thomwolf's avatar
WIP  
thomwolf committed
448
449
450
451


@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**:
            tuple of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, 2, num_heads, sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import GPT2Tokenizer, TFGPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')

    def call(self, inputs, training=False):
        """Run the transformer and project its last hidden states onto the vocabulary.

        Returns:
            tuple ``(lm_logits, presents, (all hidden_states), (attentions))``.
        """
        transformer_outputs = self.transformer(inputs, training=training)
        hidden_states = transformer_outputs[0]

        # The LM head reuses the token embedding matrix (tied weights) as its output projection.
        lm_logits = self.transformer.wte(hidden_states, mode="linear")

        outputs = (lm_logits,) + transformer_outputs[1:]

        return outputs  # lm_logits, presents, (all hidden_states), (attentions)
thomwolf's avatar
WIP  
thomwolf committed
495
496
497
498
499
500


@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""", GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
    r"""
        **mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1[``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiplechoice classification head (scores for each choice before SoftMax).
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import torch
        from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2DoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')
        self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

    def call(self, inputs, training=False):
        """Run the transformer over every choice and apply the LM and multiple-choice heads.

        Args:
            inputs: one of
                - a single ``input_ids`` tensor;
                - a tuple/list ``(input_ids, mc_token_ids, past, attention_mask,
                  token_type_ids, position_ids, head_mask)``, where trailing
                  entries may be omitted;
                - a dict keyed by those same names.
            training: training-mode flag, forwarded to the transformer and to
                the multiple-choice head.

        Returns:
            ``(lm_logits, mc_logits) + transformer_outputs[1:]``: LM logits,
            multiple-choice logits, then presents and, optionally, all hidden
            states and attentions.

        Raises:
            AssertionError: if more than 7 inputs are given.
        """
        if not isinstance(inputs, (dict, tuple, list)):
            input_ids = inputs
            # Every optional input defaults to None. A chained assignment cannot
            # get the target/value count out of sync; the previous 6-target /
            # 5-value tuple assignment raised ValueError here.
            mc_token_ids = past = attention_mask = token_type_ids = position_ids = head_mask = None
        elif isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            mc_token_ids = inputs[1] if len(inputs) > 1 else None
            past = inputs[2] if len(inputs) > 2 else None
            attention_mask = inputs[3] if len(inputs) > 3 else None
            token_type_ids = inputs[4] if len(inputs) > 4 else None
            position_ids = inputs[5] if len(inputs) > 5 else None
            head_mask = inputs[6] if len(inputs) > 6 else None
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs.get('input_ids')
            mc_token_ids = inputs.get('mc_token_ids', None)
            past = inputs.get('past', None)
            attention_mask = inputs.get('attention_mask', None)
            token_type_ids = inputs.get('token_type_ids', None)
            position_ids = inputs.get('position_ids', None)
            head_mask = inputs.get('head_mask', None)
            assert len(inputs) <= 7, "Too many inputs."

        input_shapes = shape_list(input_ids)
        seq_length = input_shapes[-1]

        # Collapse every leading dimension (e.g. batch and choices) into one, so the
        # transformer sees a plain (N, seq_length) batch. `past` and `head_mask`
        # are passed through unchanged.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]

        # Restore the original leading dimensions, keeping the hidden size last.
        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        # LM head: reuse the token-embedding layer in "linear" mode (tied weights).
        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
        # Drop the trailing size-1 axis, leaving one score per choice.
        mc_logits = tf.squeeze(mc_logits, axis=-1)

        # NOTE: transformer_outputs[1:] (presents, hidden states, attentions) are
        # returned as computed for the flattened batch; they are not reshaped back.
        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, presents, (all hidden_states), (attentions)