# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """


import logging

import numpy as np
import tensorflow as tf

from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
    TFConv1D,
    TFPreTrainedModel,
    TFSequenceSummary,
    TFSharedEmbeddings,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://cdn.huggingface.co/gpt2-tf_model.h5",
    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-tf_model.h5",
    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-tf_model.h5",
    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-tf_model.h5",
    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-tf_model.h5",
}


def gelu(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
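    # tanh approximation of the Gaussian CDF, as in the original GPT-2 code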
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


class TFAttention(tf.keras.layers.Layer):
    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
        self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(self, inputs, training=False):
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
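        # apply the causal mask: keep scores where b == 1 and push future positions
        # to a large negative value so they vanish after the softmax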
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
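        # (batch, head, seq_length, head_features) -> (batch, seq_length, n_head * head_features)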
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
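        # (batch, seq_length, n_embd) -> (batch, n_head, seq_length, head_features)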
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(self, inputs, training=False):
        x, layer_past, attention_mask, head_mask, use_cache = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
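            # unpack the cached key/value and prepend them to the current key/value along the sequence axis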
            past_key, past_value = tf.unstack(layer_past, axis=0)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)

        # to cope with keras serialization
        # we need to cast `use_cache` to correct bool
        # if it is a tensor
        if tf.is_tensor(use_cache):
            if hasattr(use_cache, "numpy"):
                use_cache = bool(use_cache.numpy())
            else:
                use_cache = True

        if use_cache is True:
            present = tf.stack([key, value], axis=0)
        else:
            present = (None,)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
        self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        h2 = self.dropout(h2, training=training)
        return h2


class TFBlock(tf.keras.layers.Layer):
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
        self.attn = TFAttention(nx, n_ctx, config, scale, name="attn")
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
        self.mlp = TFMLP(4 * nx, config, name="mlp")

    def call(self, inputs, training=False):
        x, layer_past, attention_mask, head_mask, use_cache = inputs

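        # self-attention sub-block: pre-layer norm, attention, then a residual connection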
        a = self.ln_1(x)
        output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
        a = output_attn[0]  # output_attn: a, present, (attentions)
        x = x + a

        m = self.ln_2(x)
        m = self.mlp(m, training=training)
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


@keras_serializable
class TFGPT2MainLayer(tf.keras.layers.Layer):
    config_class = GPT2Config

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFSharedEmbeddings(
            config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
        )
        self.wpe = tf.keras.layers.Embedding(
            config.n_positions,
            config.n_embd,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="wpe",
        )
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")

    def get_input_embeddings(self):
        return self.wte

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
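            # number of tokens already processed: the sequence length of the cached keys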
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
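            # default position ids continue from the length of any cached past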
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.

            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode="embedding")
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode="embedding")
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)

        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)


class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"


GPT2_START_DOCSTRING = r"""

    .. note::
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The token ids which have their past given to this model
            should not be passed as `input_ids` as they have already been computed.
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class TFGPT2Model(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2Model

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2Model.from_pretrained('gpt2')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
        outputs = self.transformer(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    def get_output_embeddings(self):
        return self.transformer.wte

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

        """
        transformer_outputs = self.transformer(inputs, **kwargs)
        hidden_states = transformer_outputs[0]

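        # project hidden states back to vocabulary logits by re-using the shared input embedding matrix (weight tying)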
        lm_logits = self.transformer.wte(hidden_states, mode="linear")

        outputs = (lm_logits,) + transformer_outputs[1:]

        return outputs  # lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config.num_labels = 1
        self.transformer = TFGPT2MainLayer(config, name="transformer")
        self.multiple_choice_head = TFSequenceSummary(
            config, initializer_range=config.initializer_range, name="multiple_choice_head"
        )

    def get_output_embeddings(self):
        return self.transformer.wte

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        use_cache=True,
        training=False,
    ):
        r"""
        mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as `input_ids` as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.


    Examples::

        # For example purposes. Not runnable.
        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        # This option is currently not implemented in TF 2.0
        raise NotImplementedError
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = tf.constant(encoded_choices)[None, :]  # Batch size: 1, number of choices: 2
        mc_token_ids = tf.constant([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
            use_cache = inputs[8] if len(inputs) > 8 else use_cache
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
            use_cache = inputs.get("use_cache", use_cache)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            input_shapes = shape_list(input_ids)
        else:
            input_shapes = shape_list(inputs_embeds)[:-1]

        seq_length = input_shapes[-1]

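        # flatten (batch_size, num_choices, seq_length) inputs to (batch_size * num_choices, seq_length)
        # so the shared transformer processes every choice in a single forward pass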
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [
            flat_input_ids,
            past,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            inputs_embeds,
            use_cache,
        ]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]

        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)

        mc_logits = tf.squeeze(mc_logits, axis=-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, presents, (all hidden_states), (attentions)