modeling_tf_gpt2.py 32.4 KB
Newer Older
thomwolf's avatar
WIP  
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """


import logging

import numpy as np
import tensorflow as tf

Aymeric Augustin's avatar
Aymeric Augustin committed
24
from .configuration_gpt2 import GPT2Config
Lysandre's avatar
TF GPT2  
Lysandre committed
25
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
26
27
from .modeling_tf_utils import (
    TFConv1D,
Aymeric Augustin's avatar
Aymeric Augustin committed
28
    TFPreTrainedModel,
29
    TFSequenceSummary,
Aymeric Augustin's avatar
Aymeric Augustin committed
30
    TFSharedEmbeddings,
31
    get_initializer,
Aymeric Augustin's avatar
Aymeric Augustin committed
32
    shape_list,
33
)
Aymeric Augustin's avatar
Aymeric Augustin committed
34

thomwolf's avatar
WIP  
thomwolf committed
35
36
37

logger = logging.getLogger(__name__)

# TF 2.0 (.h5) checkpoint URLs for each pretrained GPT-2 variant, keyed by shortcut name.
# All checkpoints live under the same S3 prefix and follow the "<shortcut>-tf_model.h5" naming scheme.
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-tf_model.h5".format(shortcut)
    for shortcut in ("gpt2", "gpt2-medium", "gpt2-large", "distilgpt2")
}
thomwolf's avatar
WIP  
thomwolf committed
44
45
46
47
48
49
50
51
52
53
54


def gelu(x):
    """Apply the Gaussian Error Linear Unit activation (tanh approximation).

    GELU is a smooth alternative to ReLU, introduced in https://arxiv.org/abs/1606.08415.
    This uses the tanh-based approximation:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: float Tensor to activate.

    Returns:
        A Tensor shaped like ``x`` with GELU applied element-wise.
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    gate = 0.5 * (1.0 + tf.tanh(inner))
    return x * gate


class TFAttention(tf.keras.layers.Layer):
    """Multi-head causal self-attention layer of a GPT-2 block.

    A single fused projection (``c_attn``) produces queries, keys and values. Cached keys/values from
    earlier steps (``layer_past``) are prepended to the current ones. The attended result is projected
    back with ``c_proj``.
    """

    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0, "n_embd ({}) must be divisible by n_head ({})".format(
            n_state, config.n_head
        )
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
        self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        # Head pruning is not implemented for the TF model; this is deliberately a no-op.
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(self, inputs, training=False):
        """Scaled dot-product attention with a causal mask.

        Args:
            inputs: list ``[q, k, v, attention_mask, head_mask]``. q, k and v have shape
                ``[batch, heads, sequence, features]``. ``attention_mask`` (or None) is added to the
                scores before the softmax. ``head_mask`` (or None) is multiplied into the attention
                probabilities.
            training: enables attention dropout when True.

        Returns:
            ``[context]``, or ``[context, attention_probs]`` when ``output_attentions`` is set.
        """
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            # Cast to the scores' dtype (not a hard-coded float32) so the division also works
            # for float16 / bfloat16 activations.
            dk = tf.cast(shape_list(k)[-1], dtype=w.dtype)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Keep allowed scores and push masked (future) positions to a large negative value before the softmax.
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """Inverse of :meth:`split_heads`: (batch, head, seq, head_features) -> (batch, seq, head * head_features)."""
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        """Split the last axis into ``n_head`` heads and move the head axis before the sequence axis."""
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(self, inputs, training=False):
        """Run self-attention.

        Args:
            inputs: list ``[x, layer_past, attention_mask, head_mask]``. ``layer_past`` is None or the
                ``present`` tensor returned by a previous call.
            training: enables dropout when True.

        Returns:
            ``[a, present]`` plus the attention probabilities when ``output_attentions`` is set.
            ``present`` stacks the full keys and values on axis 1:
            ``(batch, 2, head, seq_length, head_features)``.
        """
        x, layer_past, attention_mask, head_mask = inputs

        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=1)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)
        present = tf.stack([key, value], axis=1)

        attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


thomwolf's avatar
thomwolf committed
158
class TFMLP(tf.keras.layers.Layer):
    """Feed-forward sub-layer of a GPT-2 block.

    Applies ``c_fc`` (``TFConv1D(n_state, config.n_embd)``), then GELU, then ``c_proj``
    (``TFConv1D(config.n_embd, n_state)``), then residual dropout.
    """

    def __init__(self, n_state, config, **kwargs):
        super().__init__(**kwargs)
        hidden_size = config.n_embd
        init_range = config.initializer_range
        self.c_fc = TFConv1D(n_state, hidden_size, initializer_range=init_range, name="c_fc")
        self.c_proj = TFConv1D(hidden_size, n_state, initializer_range=init_range, name="c_proj")
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        activated = self.act(self.c_fc(x))
        projected = self.c_proj(activated)
        return self.dropout(projected, training=training)
thomwolf's avatar
WIP  
thomwolf committed
172
173


thomwolf's avatar
thomwolf committed
174
175
class TFBlock(tf.keras.layers.Layer):
    """One GPT-2 transformer block.

    LayerNorm is applied before each sub-layer: ``ln_1`` -> attention and ``ln_2`` -> MLP. Each
    sub-layer's output is added back to its input as a residual.
    """

    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        hidden_size = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=eps, name="ln_1")
        self.attn = TFAttention(hidden_size, n_ctx, config, scale, name="attn")
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=eps, name="ln_2")
        self.mlp = TFMLP(4 * hidden_size, config, name="mlp")

    def call(self, inputs, training=False):
        hidden, layer_past, attention_mask, head_mask = inputs

        normed = self.ln_1(hidden)
        attn_outputs = self.attn([normed, layer_past, attention_mask, head_mask], training=training)
        # attn_outputs: a, present, (attentions)
        hidden = hidden + attn_outputs[0]

        mlp_out = self.mlp(self.ln_2(hidden), training=training)
        hidden = hidden + mlp_out

        return [hidden] + attn_outputs[1:]  # x, present, (attentions)

thomwolf's avatar
thomwolf committed
198
199
200

class TFGPT2MainLayer(tf.keras.layers.Layer):
    """GPT-2 transformer stack shared by the TF GPT-2 model classes.

    Sums token embeddings (``wte``) and position embeddings (``wpe``), applies dropout, runs
    ``config.n_layer`` :class:`TFBlock` layers and finishes with a LayerNorm (``ln_f``).
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFSharedEmbeddings(
            config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
        )
        self.wpe = tf.keras.layers.Embedding(
            config.n_positions,
            config.n_embd,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="wpe",
        )
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")

    def get_input_embeddings(self):
        return self.wte

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError("Resizing token embeddings is not implemented for TFGPT2MainLayer.")

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

            Not implemented for the TF GPT-2 model; always raises NotImplementedError.
        """
        raise NotImplementedError("Head pruning is not implemented for TFGPT2MainLayer.")

    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        training=False,
    ):
        """Run the GPT-2 transformer stack.

        ``inputs`` may take one of three forms:

        - a single ``input_ids`` tensor;
        - a list/tuple in the order ``(input_ids, past, attention_mask, token_type_ids, position_ids,
          head_mask, inputs_embeds)``;
        - a dict keyed by those names.

        Values found in ``inputs`` take precedence over the keyword arguments.

        Returns:
            ``(last_hidden_state, presents)``, followed by ``all_hidden_states`` when
            ``output_hidden_states`` is set and ``all_attentions`` when ``output_attentions`` is set.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
            NotImplementedError: if a ``head_mask`` is given.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            assert len(inputs) <= 7, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Each cached layer tensor stacks (key, value) on axis 1, so past[0][0] is
            # (2, num_heads, past_seq_len, head_features) and axis -2 is the cached length.
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError("head_mask is not supported yet by the TF GPT-2 model.")
        head_mask = [None] * self.num_hidden_layers

        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode="embedding")
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            # There is no separate token-type embedding layer: token types are looked up in `wte`.
            token_type_embeds = self.wte(token_type_ids, mode="embedding")
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)

thomwolf's avatar
thomwolf committed
360

thomwolf's avatar
thomwolf committed
361
class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """Abstract base for the TF GPT-2 models.

    Plugs the GPT-2 configuration class and the TF GPT-2 checkpoint archive map into the generic
    :class:`TFPreTrainedModel` weight-loading interface. The main layer is stored under the
    ``transformer`` attribute prefix.
    """

    base_model_prefix = "transformer"
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    config_class = GPT2Config


Lysandre's avatar
Lysandre committed
371
GPT2_START_DOCSTRING = r"""
thomwolf's avatar
WIP  
thomwolf committed
372

Lysandre's avatar
TF GPT2  
Lysandre committed
373
    .. note::
thomwolf's avatar
thomwolf committed
374
375
376
377
378
        TF 2.0 models accepts two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

Lysandre's avatar
Lysandre committed
379
        This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having
Lysandre's avatar
TF GPT2  
Lysandre committed
380
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
thomwolf's avatar
thomwolf committed
381

Lysandre's avatar
Lysandre committed
382
        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
Lysandre's avatar
TF GPT2  
Lysandre committed
383
        in the first positional argument :
thomwolf's avatar
thomwolf committed
384

Lysandre's avatar
TF GPT2  
Lysandre committed
385
        - a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
thomwolf's avatar
thomwolf committed
386
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
Lysandre's avatar
TF GPT2  
Lysandre committed
387
388
389
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
thomwolf's avatar
WIP  
thomwolf committed
390
391

    Parameters:
392
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
thomwolf's avatar
WIP  
thomwolf committed
393
            Initializing with a config file does not load the weights associated with the model, only the configuration.
394
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
thomwolf's avatar
WIP  
thomwolf committed
395
396
"""

Lysandre's avatar
Lysandre committed
397
GPT2_INPUTS_DOCSTRING = r"""
Lysandre's avatar
TF GPT2  
Lysandre committed
398
399
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Lysandre's avatar
Lysandre committed
400
401
            Indices of input sequence tokens in the vocabulary.

Lysandre's avatar
TF GPT2  
Lysandre committed
402
            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
403
            See :func:`transformers.PreTrainedTokenizer.encode` and
Lysandre's avatar
TF GPT2  
Lysandre committed
404
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
Lysandre's avatar
Lysandre committed
405

Lysandre's avatar
TF GPT2  
Lysandre committed
406
407
408
409
410
411
            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
thomwolf's avatar
thomwolf committed
412
413
414
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
Lysandre's avatar
Lysandre committed
415

Lysandre's avatar
TF GPT2  
Lysandre committed
416
            `What are attention masks? <../glossary.html#attention-mask>`__
Lysandre's avatar
Lysandre committed
417
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Lysandre's avatar
TF GPT2  
Lysandre committed
418
419
420
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
Lysandre's avatar
Lysandre committed
421

Lysandre's avatar
TF GPT2  
Lysandre committed
422
423
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
thomwolf's avatar
thomwolf committed
424
425
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
Lysandre's avatar
Lysandre committed
426

Lysandre's avatar
TF GPT2  
Lysandre committed
427
428
            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
thomwolf's avatar
WIP  
thomwolf committed
429
430
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
Lysandre's avatar
TF GPT2  
Lysandre committed
431
432
433
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        input_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
434
435
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
Lysandre's avatar
Lysandre committed
436
437
438
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
thomwolf's avatar
WIP  
thomwolf committed
439
440
"""

441
442
443
444
445

@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class TFGPT2Model(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        past (:obj:`tuple(tf.Tensor)` of length :obj:`config.n_layers` with each tensor of shape :obj:`(batch_size, 2, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2Model

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2Model.from_pretrained('gpt2')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
        outputs = self.transformer(inputs, **kwargs)
        return outputs
thomwolf's avatar
WIP  
thomwolf committed
488
489


490
491
@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # No separate output layer is created: `call` projects hidden states
        # through the transformer's token embedding layer (`wte`) instead.
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    def get_output_embeddings(self):
        """Return the token embedding layer, which also serves as the LM output projection."""
        return self.transformer.wte

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        """Assemble the model inputs for one decoding step.

        When cached attention states (`past`) are available, only the most
        recent token of each sequence is fed to the model (kept with a
        trailing length-1 axis); otherwise `inputs` is passed through as-is.
        """
        step_inputs = tf.expand_dims(inputs[:, -1], -1) if past else inputs
        return {"inputs": step_inputs, "past": past}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2LMHeadModel.from_pretrained('gpt2')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

        """
        transformer_outputs = self.transformer(inputs, **kwargs)
        # Weight tying: the embedding layer in "linear" mode maps the final
        # hidden states to per-token vocabulary scores.
        lm_logits = self.transformer.wte(transformer_outputs[0], mode="linear")
        # lm_logits, presents, (all hidden_states), (attentions)
        return (lm_logits,) + transformer_outputs[1:]
thomwolf's avatar
WIP  
thomwolf committed
554
555


556
557
@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state of a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # One score per choice: `call` squeezes the head's trailing size-1 axis.
        # (Presumably TFSequenceSummary sizes its projection from
        # config.num_labels.) NOTE: this mutates the caller's config in place.
        config.num_labels = 1
        self.transformer = TFGPT2MainLayer(config, name="transformer")
        self.multiple_choice_head = TFSequenceSummary(
            config, initializer_range=config.initializer_range, name="multiple_choice_head"
        )

    def get_output_embeddings(self):
        """Return the token embedding layer, which also serves as the LM output projection."""
        return self.transformer.wte

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        training=False,
    ):
        r"""
        mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, sequence_length - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size * num_choices, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size * num_choices, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size * num_choices, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

        The transformer runs on inputs flattened to ``(batch_size * num_choices, sequence_length)``. As a result,
        ``past``, ``hidden_states`` and ``attentions`` carry that flattened leading dimension.

    Examples::

        # For example purposes. Not runnable.
        import tensorflow as tf
        from transformers import GPT2Tokenizer, TFGPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = TFGPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        # This option is currently not implemented in TF 2.0
        raise NotImplementedError
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = tf.constant(encoded_choices)[None, :]  # Batch size: 1, number of choices: 2
        mc_token_ids = tf.constant([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        # `inputs` may bundle every argument positionally (tuple/list, in
        # signature order), by name (dict), or be the input ids alone.
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        # Shape of the token grid (e.g. batch_size, num_choices, sequence_length),
        # taken from the ids, or from the embeddings minus their hidden axis.
        if input_ids is not None:
            input_shapes = shape_list(input_ids)
        else:
            input_shapes = shape_list(inputs_embeds)[:-1]
        seq_length = input_shapes[-1]

        # Fold any extra leading axes (e.g. the choice axis) into the batch
        # dimension before calling the transformer. inputs_embeds is flattened
        # too (keeping its hidden axis) so it matches the other inputs; for an
        # already rank-3 embedding tensor this reshape is a no-op.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[-1]))
            if inputs_embeds is not None
            else None
        )

        flat_inputs = [
            flat_input_ids,
            past,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
        ]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]

        # Restore the original leading axes: (*input_shapes, hidden_size).
        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        # LM head reuses the token embeddings; the classification head summarizes
        # the hidden state at `mc_token_ids` for each choice.
        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
        mc_logits = tf.squeeze(mc_logits, axis=-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, presents, (all hidden_states), (attentions)