"tools/vscode:/vscode.git/clone" did not exist on "afe88104ac750a5c865c42941c3b708a36c6ae8f"
modeling_tf_gpt2.py 34 KB
Newer Older
thomwolf's avatar
WIP  
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

thomwolf's avatar
thomwolf committed
31
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
thomwolf's avatar
WIP  
thomwolf committed
32
33
34
35
36
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings

# Module-level logger, used by the PyTorch weight-loading helper in this file.
logger = logging.getLogger(name=__name__)

thomwolf's avatar
thomwolf committed
37
# Shortcut name -> download URL of the pretrained TF 2.0 (HDF5) weights.
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5",
}


def load_gpt2_pt_weights_in_tf(tf_model, config, pytorch_checkpoint_path):
    """ Load the weights of a PyTorch GPT-2 checkpoint into a TF 2.0 Keras model, in place.

        The TF model is first called on a dummy batch so that all of its variables exist.
        Every TF variable is then matched by name to an entry of the PyTorch state dict,
        and all values are assigned in a single batch. No file is written here. Callers can
        save the resulting model themselves, e.g. in HDF5 format, which makes transfer
        learning easy. See
        https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357

        Args:
            tf_model: TF 2.0 GPT-2 model whose weights are overwritten.
            config: model configuration. Not used by this function; kept for the loader signature.
            pytorch_checkpoint_path: path of a file readable with ``torch.load``
                (presumably a PyTorch ``state_dict``; it is indexed by parameter name below).

        Returns:
            ``tf_model``, with its weights set from the checkpoint.

        Raises:
            ImportError: if PyTorch is not installed.
            AssertionError: if a TF weight has no counterpart in the checkpoint, or the shapes differ.
    """
    try:
        import torch
        from tensorflow.python.keras import backend as K
    except ImportError:
        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
            "https://pytorch.org/ for installation instructions.")
        raise

    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info("Loading PyTorch weights from {}".format(pt_path))
    # Load pytorch model
    state_dict = torch.load(pt_path, map_location='cpu')

    # Call the model once on a dummy batch so Keras creates all of its variables.
    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tf_model(tf_inputs, training=False)

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    weight_value_tuples = []
    for symbolic_weight in symbolic_weights:
        # Strip the ':0' output suffix and drop the outermost (model) name scope.
        scopes = symbolic_weight.name.replace(':0', '').split('/')[1:]

        name = []
        for scope in scopes:
            # Transformer blocks are created as 'h_{i}' in TF (see TFGPT2MainLayer),
            # whereas the PyTorch module list indexes them as 'h.{i}'.
            if scope.startswith('h_') and scope[2:].isdigit():
                name.extend(['h', scope[2:]])
            else:
                name.append(scope)

        # Variables named 'kernel' are stored transposed relative to the PyTorch layout.
        transpose = bool(name[-1] == 'kernel')
        # Map Keras variable names onto PyTorch parameter names. LayerNormalization
        # calls its scale/offset 'gamma'/'beta'; PyTorch LayerNorm uses 'weight'/'bias'.
        if name[-1] in ('kernel', 'embeddings', 'gamma'):
            name[-1] = 'weight'
        elif name[-1] == 'beta':
            name[-1] = 'bias'

        name = '.'.join(name)
        assert name in state_dict, "{} not found in the PyTorch checkpoint".format(name)
        array = state_dict[name].numpy()

        if transpose:
            array = np.transpose(array)

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            # Attach both shapes so the mismatch is diagnosable from the traceback.
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        logger.info("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))

    K.batch_set_value(weight_value_tuples)

    tf_model(tf_inputs, training=False)  # Make sure restore ops are run
    return tf_model


def gelu(x):
    """Apply the Gaussian Error Linear Unit activation (tanh approximation).

    GELU is a smooth alternative to ReLU, introduced in
    https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to activate.
    Returns:
        ``x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``, same shape as `x`.
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    return x * (0.5 * (1.0 + tf.tanh(inner)))


class TFAttention(tf.keras.layers.Layer):
    """Multi-head causal self-attention used inside each GPT-2 block.

    The input is projected to queries/keys/values by a single ``TFConv1D``.
    Cached keys/values from ``layer_past`` are optionally prepended. Scores are
    masked causally, plus an optional additive attention mask, and the merged
    heads are projected back to ``nx`` features.
    """

    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super(TFAttention, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = TFConv1D(n_state * 3, nx, name='c_attn')
        self.c_proj = TFConv1D(n_state, nx, name='c_proj')
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        # Head pruning is not implemented for the TF 2.0 model yet; this is a no-op.
        pass

    @staticmethod
    @tf.function
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    @tf.function
    def _attn(self, inputs, training=False):
        """Scaled, causally-masked dot-product attention.

        Args:
            inputs: list ``[q, k, v, attention_mask, head_mask]``. The masks may be None.
            training: whether attention dropout is active.

        Returns:
            ``[context]``, plus the attention probabilities when ``output_attentions``.
        """
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            # Divide by sqrt(head_features) in the scores' own dtype, so non-float32 models work too.
            dk = tf.cast(tf.shape(k)[-1], w.dtype)
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Keep allowed scores and push future positions to a large negative value before softmax.
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the (additive) attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        # Pass `training` explicitly so dropout follows the flag, not Keras' implicit learning phase.
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    @tf.function
    def merge_heads(self, x):
        """[batch, heads, seq, head_features] -> [batch, seq, heads * head_features]."""
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    @tf.function
    def split_heads(self, x):
        """[batch, seq, features] -> [batch, heads, seq, features // heads]."""
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    @tf.function
    def call(self, inputs, training=False):
        """Self-attention forward pass.

        Args:
            inputs: list ``[x, layer_past, attention_mask, head_mask]``. ``layer_past``
                is either None or a previous ``present`` of this layer.
            training: whether dropout is active.

        Returns:
            ``[a, present]``, plus the attention weights when ``output_attentions``.
            ``present`` stacks the (possibly extended) key and value along axis 1.
        """
        x, layer_past, attention_mask, head_mask = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=1)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)
        present = tf.stack([key, value], axis=1)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


thomwolf's avatar
thomwolf committed
224
class TFMLP(tf.keras.layers.Layer):
    """Position-wise feed-forward sub-layer: c_fc -> gelu -> c_proj -> dropout."""

    def __init__(self, n_state, config, **kwargs):
        super(TFMLP, self).__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx, name='c_fc')
        self.c_proj = TFConv1D(nx, n_state, name='c_proj')
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    @tf.function
    def call(self, x, training=False):
        """Apply the feed-forward network to `x`. Output has the same last dim (n_embd) as `x`."""
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        # Pass `training` explicitly so dropout follows the flag, not Keras' implicit learning phase.
        h2 = self.dropout(h2, training=training)
        return h2
thomwolf's avatar
WIP  
thomwolf committed
240
241


thomwolf's avatar
thomwolf committed
242
243
244
class TFBlock(tf.keras.layers.Layer):
    """One GPT-2 transformer layer.

    LayerNorm -> self-attention -> residual add, then
    LayerNorm -> feed-forward -> residual add.
    """

    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super(TFBlock, self).__init__(**kwargs)
        hidden_dim = config.n_embd
        eps = config.layer_norm_epsilon
        # Sub-layers are created in the same order as before so weight tracking order is unchanged.
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=eps, name='ln_1')
        self.attn = TFAttention(hidden_dim, n_ctx, config, scale, name='attn')
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=eps, name='ln_2')
        self.mlp = TFMLP(4 * hidden_dim, config, name='mlp')

    @tf.function
    def call(self, inputs, training=False):
        """Run the block on ``[x, layer_past, attention_mask, head_mask]``.

        Returns ``[x, present]``, plus the attention weights when the attention layer emits them.
        """
        hidden, layer_past, attention_mask, head_mask = inputs

        attn_results = self.attn([self.ln_1(hidden), layer_past, attention_mask, head_mask],
                                 training=training)
        hidden = hidden + attn_results[0]

        feed_forward = self.mlp(self.ln_2(hidden), training=training)
        hidden = hidden + feed_forward

        # Forward `present` (and attentions, if any) unchanged from the attention sub-layer.
        return [hidden] + attn_results[1:]

thomwolf's avatar
thomwolf committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class TFGPT2Embeddings(tf.keras.layers.Layer):
    """Token embedding matrix, shared between input lookup and output projection.

    In "embedding" mode it maps token ids to vectors. In "linear" mode it projects
    hidden states onto the vocabulary by multiplying with the transposed matrix.
    Position embeddings are not part of this layer.
    """
    def __init__(self, config, **kwargs):
        super(TFGPT2Embeddings, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size

    def build(self, input_shape):
        """Build shared word embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=tf.random_normal_initializer(
                mean=0., stddev=self.hidden_size**-0.5))
        super(TFGPT2Embeddings, self).build(input_shape)

    @tf.function
    def call(self, inputs, mode="embedding"):
        """Look up token embeddings, or compute vocabulary logits.

        Args:
            inputs: in "embedding" mode, an integer tensor of token ids (e.g. shape
                [batch_size, length]); in "linear" mode, a float tensor of hidden
                states with shape [..., hidden_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", the embeddings, with shape
                ``inputs.shape + [hidden_size]``; (2) if mode == "linear", logits
                with shape [..., vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Gather the rows of the embedding matrix selected by `input_ids`."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        # Flatten leading dims for a single 2-D matmul, then restore them.
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
thomwolf's avatar
thomwolf committed
327
328
329
330
331
332

class TFGPT2MainLayer(tf.keras.layers.Layer):
    """Bare GPT-2 transformer stack.

    Token embeddings (``wte``) + position embeddings (``wpe``), ``config.n_layer``
    ``TFBlock`` layers, then a final LayerNorm (``ln_f``).
    """

    def __init__(self, config, *inputs, **kwargs):
        # `config` is consumed here and must not be forwarded: the first positional
        # parameter of tf.keras.layers.Layer.__init__ is `trainable`.
        super(TFGPT2MainLayer, self).__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFGPT2Embeddings(config, name='wte')
        self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    @tf.function
    def call(self, inputs, training=False):
        """Run the GPT-2 stack.

        Args:
            inputs: one of
                - a single ``input_ids`` tensor;
                - a list/tuple ``[input_ids, past, attention_mask, token_type_ids,
                  position_ids, head_mask]`` (trailing entries optional);
                - a dict with any of those six keys.
            training: whether dropout is active.

        Returns:
            ``(last_hidden_state, presents)``, followed by ``all_hidden_states``
            when ``output_hidden_states`` and by ``all_attentions`` when ``output_attentions``.

        Raises:
            NotImplementedError: if a ``head_mask`` is given.
        """
        if not isinstance(inputs, (dict, tuple, list)):
            input_ids = inputs
            past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None
        elif isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else None
            attention_mask = inputs[2] if len(inputs) > 2 else None
            token_type_ids = inputs[3] if len(inputs) > 3 else None
            position_ids = inputs[4] if len(inputs) > 4 else None
            head_mask = inputs[5] if len(inputs) > 5 else None
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', None)
            attention_mask = inputs.get('attention_mask', None)
            token_type_ids = inputs.get('token_type_ids', None)
            position_ids = inputs.get('position_ids', None)
            head_mask = inputs.get('head_mask', None)
            # Six keys are read above, so a dict that provides all of them is valid.
            assert len(inputs) <= 6, "Too many inputs."

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Each cached entry is a `present` from TFAttention, i.e. tf.stack([key, value], axis=1).
            # past[0][0] is therefore [2, heads, past_seq, head_features], and -2 is the cached length.
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        if head_mask is not None:
            raise NotImplementedError
        head_mask = [None] * self.num_hidden_layers

        input_shape = shape_list(input_ids)
        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        inputs_embeds = self.wte(input_ids, mode='embedding')
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            # GPT-2 has no segment vocabulary: token types are embedded with the token matrix.
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode='embedding')
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        # Pass `training` explicitly so dropout follows the flag, not Keras' implicit learning phase.
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)

        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)

thomwolf's avatar
thomwolf committed
461

thomwolf's avatar
thomwolf committed
462
class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """ Abstract base class for the TF 2.0 GPT-2 models.

        Handles weights initialization and provides a simple interface for
        downloading and loading pretrained models. It binds the GPT-2
        configuration class, the pretrained-weights archive map and the
        PyTorch-checkpoint loader.
    """
    base_model_prefix = "transformer"
    config_class = GPT2Config
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_gpt2_pt_weights_in_tf


GPT2_START_DOCSTRING = r"""    OpenAI GPT-2 model was proposed in
    `Language Models are Unsupervised Multitask Learners`_
    by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
    corpus of ~40 GB of text data.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`Language Models are Unsupervised Multitask Learners`:
        https://openai.com/blog/better-language-models/

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Important note on the model inputs:
        The inputs of the TF 2.0 models are slightly different from the PyTorch ones since
        TF 2.0 Keras doesn't accept named arguments with default values for input Tensors.
        More precisely, input Tensors are gathered in the first argument of the model call function: `model(inputs)`.
        There are three possibilities to gather and feed the inputs to the model:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, past])` or `model([input_ids, past, attention_mask])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.GPT2Tokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
            Not yet supported by this TF 2.0 implementation: passing a head mask raises ``NotImplementedError``.
"""

@add_start_docstrings("The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top.",
                      GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2Model(TFGPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**:
            list of ``tf.Tensor`` (one for each layer)
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2Model.from_pretrained('gpt2')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
        # All of the computation lives in the shared GPT-2 main layer; this model adds no head.
        self.transformer = TFGPT2MainLayer(config, name='transformer')

    @tf.function
    def call(self, inputs, training=False):
        # ``inputs`` is forwarded untouched, so this model accepts exactly what TFGPT2MainLayer accepts.
        outputs = self.transformer(inputs, training=training)
        return outputs


@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**:
            list of ``tf.Tensor`` (one for each layer)
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import GPT2Tokenizer
        from pytorch_transformers.modeling_tf_gpt2 import TFGPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]  # The LM prediction scores are the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')

    @tf.function
    def call(self, inputs, training=False):
        transformer_outputs = self.transformer(inputs, training=training)
        hidden_states = transformer_outputs[0]

        # The LM head owns no weights: the transformer's token-embedding layer is reused in
        # "linear" mode to project hidden states onto the vocabulary (weights tied to the input embeddings).
        lm_logits = self.transformer.wte(hidden_states, mode="linear")

        outputs = (lm_logits,) + transformer_outputs[1:]

        return outputs  # lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings("""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the input of a specified classification token index in the input sequence.
""", GPT2_START_DOCSTRING, GPT2_INPUTS_DOCSTRING)
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
    r"""
        **mc_token_ids**: (`optional`, default to index of the last token of the input) ``tf.Tensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.shape[-1] - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiple-choice classification head (scores for each choice before SoftMax).
        **past**:
            list of ``tf.Tensor`` (one for each layer)
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size * num_choices, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size * num_choices, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Only the two prediction scores are reshaped back to ``(batch_size, num_choices, ...)``;
        ``past``, ``hidden_states`` and ``attentions`` are returned as computed on the flattened
        ``(batch_size * num_choices, sequence_length)`` inputs.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import GPT2Tokenizer
        from pytorch_transformers.modeling_tf_gpt2 import TFGPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The new token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = tf.constant(encoded_choices)[None, :]  # Batch size: 1, number of choices: 2
        mc_token_ids = tf.constant([cls_token_location])  # Batch size: 1

        # Optional inputs go positionally in a list/tuple (or in a dict), not as keyword arguments.
        outputs = model([input_ids, mc_token_ids])
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFGPT2DoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name='transformer')
        self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')

    @tf.function
    def call(self, inputs, training=False):
        """Run GPT-2 over every choice and apply the LM and multiple-choice heads.

        ``inputs`` is either a single tensor (``input_ids``), a tuple/list ordered as
        ``(input_ids, mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask)``
        (trailing entries may be omitted), or a dict keyed by those same names.
        """
        if not isinstance(inputs, (dict, tuple, list)):
            input_ids = inputs
            # No optional inputs were given: every one of the six defaults to None.
            mc_token_ids = past = attention_mask = token_type_ids = position_ids = head_mask = None
        elif isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            mc_token_ids = inputs[1] if len(inputs) > 1 else None
            past = inputs[2] if len(inputs) > 2 else None
            attention_mask = inputs[3] if len(inputs) > 3 else None
            token_type_ids = inputs[4] if len(inputs) > 4 else None
            position_ids = inputs[5] if len(inputs) > 5 else None
            head_mask = inputs[6] if len(inputs) > 6 else None
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs.get('input_ids')
            mc_token_ids = inputs.get('mc_token_ids', None)
            past = inputs.get('past', None)
            attention_mask = inputs.get('attention_mask', None)
            token_type_ids = inputs.get('token_type_ids', None)
            position_ids = inputs.get('position_ids', None)
            head_mask = inputs.get('head_mask', None)
            # Seven keys are read above (same inputs as the tuple form), so up to seven are legal.
            assert len(inputs) <= 7, "Too many inputs."

        # input_ids is expected to be (batch_size, num_choices, sequence_length). Per-token inputs
        # are flattened to (-1, sequence_length) so the main layer sees an ordinary batch of sequences.
        input_shapes = shape_list(input_ids)
        seq_length = input_shapes[-1]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]

        # Restore the original leading dimensions on the last hidden states only.
        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        # LM head: reuse the token-embedding layer in "linear" mode (weights tied to the input embeddings).
        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        # Multiple-choice head: summarize each choice at the token index given by mc_token_ids.
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)

        # Drop the trailing size-1 axis left by the summary head.
        mc_logits = tf.squeeze(mc_logits, axis=-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, presents, (all hidden_states), (attentions)