"tests/test_modeling_flax_gptj.py" did not exist on "70996a5420f6b28cb0330e373b99f75893c8fbb3"
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """


from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import tensorflow as tf

from .configuration_gpt2 import GPT2Config
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFConv1D,
    TFPreTrainedModel,
    TFSequenceSummary,
    TFSharedEmbeddings,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def gelu(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
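    # `cdf` approximates Phi(x) (the Gaussian CDF) with a tanh, so `x * cdf` below is the usual
    # tanh approximation of the exact GELU x * Phi(x). In recent TensorFlow releases (>= 2.4) the
    # same approximation should also be available as tf.keras.activations.gelu(x, approximate=True);
    # a local implementation is kept here so this module does not rely on that newer API.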
    return x * cdf


class TFAttention(tf.keras.layers.Layer):
    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.n_ctx = n_ctx
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.output_attentions = config.output_attentions

        self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
        self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        pass

    @staticmethod
    def causal_attention_mask(nd, ns, dtype):
        """1's in the lower triangle, counting from the lower right corner.
        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
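
        For example, with nd=3 and ns=5 (two cached past positions plus three new positions)
        the mask is::

            [[1, 1, 1, 0, 0],
             [1, 1, 1, 1, 0],
             [1, 1, 1, 1, 1]]

        i.e. query position i may attend to key position j whenever j <= i + (ns - nd).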
        """
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(shape_list(k)[-1], tf.float32)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
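        # Positions the causal mask disallows (b == 0) are pushed to a large negative value below,
        # so they receive ~zero weight after the softmax.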
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)

    def split_heads(self, x):
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
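        # e.g. with n_embd=768 and n_head=12: (batch, seq_length, 768) -> (batch, seq_length, 12, 64);
        # the transpose below then moves the head axis in front of the sequence axis.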
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)

    def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        x = self.c_attn(x)
        query, key, value = tf.split(x, 3, axis=2)
        query = self.split_heads(query)
        key = self.split_heads(key)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            key = tf.concat([past_key, key], axis=-2)
            value = tf.concat([past_value, value], axis=-2)

        # to cope with keras serialization
        if use_cache:
            present = tf.stack([key, value], axis=0)
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a, training=training)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_fc")
        self.c_proj = TFConv1D(nx, n_state, initializer_range=config.initializer_range, name="c_proj")
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        h2 = self.dropout(h2, training=training)
        return h2


class TFBlock(tf.keras.layers.Layer):
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super().__init__(**kwargs)
        nx = config.n_embd
        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
        self.attn = TFAttention(nx, n_ctx, config, scale, name="attn")
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
        self.mlp = TFMLP(inner_dim, config, name="mlp")

    def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        a = self.ln_1(x)
        output_attn = self.attn(
            a, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=training
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        x = x + a

        m = self.ln_2(x)
        m = self.mlp(m, training=training)
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


@keras_serializable
class TFGPT2MainLayer(tf.keras.layers.Layer):
    config_class = GPT2Config

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.use_cache = config.use_cache
        self.return_dict = config.use_return_dict

        self.num_hidden_layers = config.n_layer
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

        self.wte = TFSharedEmbeddings(
            config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
        )
        self.wpe = tf.keras.layers.Embedding(
            config.n_positions,
            config.n_embd,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="wpe",
        )
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte.weight = value
        self.wte.vocab_size = self.wte.weight.shape[0]

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
            return_dict = inputs[10] if len(inputs) > 10 else return_dict
            assert len(inputs) <= 11, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 11, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.use_cache
        return_dict = return_dict if return_dict is not None else self.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
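            # For example, a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0] and is then
            # added to the raw attention scores in every block.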

            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode="embedding")
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode="embedding")
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

            outputs = block(
                hidden_states,
                layer_past,
                attention_mask,
                head_mask[i],
                use_cache,
                output_attentions,
                training=training,
            )

            hidden_states, present = outputs[:2]
            if use_cache:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    base_model_prefix = "transformer"


@dataclass
class TFGPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of the GPT-2 model with a language modeling head and a multiple choice classification head on top.

    Args:
        logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`tf.Tensor` of length :obj:`config.n_layers`,  with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            ``past_key_values`` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: tf.Tensor = None
    mc_logits: tf.Tensor = None
    past_key_values: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


GPT2_START_DOCSTRING = r"""

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
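
        For example, the following calls would be equivalent (a rough sketch, assuming
        ``input_ids`` and ``attention_mask`` tensors built with :class:`transformers.GPT2Tokenizer`)::

            outputs = model(input_ids, attention_mask=attention_mask)
            outputs = model([input_ids, attention_mask])
            outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})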

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The token ids which have their past given to this model
            should not be passed as `input_ids` as they have already been computed.
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class TFGPT2Model(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=TFBaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFGPT2MainLayer(config, name="transformer")

    def get_output_embeddings(self):
        return self.transformer.wte

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
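        # (the key/value states of every earlier token are already cached in `past`, so during
        # generation only the newest token id has to be fed through the model)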
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
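
    # Rough sketch of the resulting generation loop (illustrative only, assuming `model` is a
    # TFGPT2LMHeadModel and the id tensors come from GPT2Tokenizer): the first call encodes the
    # full prompt, later calls feed only the newest token together with the cached `past`:
    #
    #     out = model(input_ids, use_cache=True, return_dict=True)
    #     out = model(next_token_ids, past=out.past_key_values, use_cache=True, return_dict=True)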

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=TFCausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the cross entropy language modeling loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[11] if len(inputs) > 11 else labels
            if len(inputs) > 11:
                inputs = inputs[:11]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        hidden_states = transformer_outputs[0]

        logits = self.transformer.wte(hidden_states, mode="linear")

        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
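            # e.g. for a sequence [t0, t1, t2, t3], the logits at positions 0..2 are scored
            # against the targets [t1, t2, t3] (standard next-token prediction)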
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config.num_labels = 1
        self.transformer = TFGPT2MainLayer(config, name="transformer")
        self.multiple_choice_head = TFSequenceSummary(
            config, initializer_range=config.initializer_range, name="multiple_choice_head"
        )

    def get_output_embeddings(self):
        return self.transformer.wte

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        r"""
            mc_token_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
                Index of the classification token in each input sequence.
                Selected in the range ``[0, input_ids.size(-1) - 1]``.

        Return:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import GPT2Tokenizer, TFGPT2DoubleHeadsModel

            >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            >>> model = TFGPT2DoubleHeadsModel.from_pretrained('gpt2')

            >>> # Add a [CLS] to the vocabulary (we should train it also!)
            >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})

            >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size

            >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
            >>> encoded_choices = [tokenizer.encode(s) for s in choices]
            >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

            >>> input_ids = tf.constant(encoded_choices)[None, :]  # Batch size: 1, number of choices: 2
            >>> mc_token_ids = tf.constant([cls_token_location])  # Batch size: 1

            >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
            >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
            use_cache = inputs[8] if len(inputs) > 8 else use_cache
            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
            return_dict = inputs[11] if len(inputs) > 11 else return_dict
            assert len(inputs) <= 12, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 12, "Too many inputs."
        else:
            input_ids = inputs
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict

        if input_ids is not None:
            input_shapes = shape_list(input_ids)
        else:
            input_shapes = shape_list(inputs_embeds)[:-1]

        seq_length = input_shapes[-1]
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        transformer_outputs = self.transformer(
            flat_input_ids,
            past,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            inputs_embeds,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = transformer_outputs[0]
        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
        mc_logits = tf.squeeze(mc_logits, axis=-1)

        if not return_dict:
            return (lm_logits, mc_logits) + transformer_outputs[1:]

        return TFGPT2DoubleHeadsModelOutput(
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )