import logging

import tensorflow as tf

from .configuration_electra import ElectraConfig

from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import (
    TFQuestionAnsweringLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "ElectraTokenizer"

TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/electra-small-generator",
    "google/electra-base-generator",
    "google/electra-large-generator",
    "google/electra-small-discriminator",
    "google/electra-base-discriminator",
    "google/electra-large-discriminator",
    # See all ELECTRA models at https://huggingface.co/models?filter=electra
]


class TFElectraEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.initializer_range = config.initializer_range

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Get token embeddings of inputs.
        Args:
            inputs: list of four tensors (input_ids, position_ids, token_type_ids, inputs_embeds); the id
                tensors are int tensors with shape [batch_size, length]
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) if mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) if mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
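
        Illustrative usage (``layer`` stands for a built ``TFElectraEmbeddings`` instance; the tensor
        names are hypothetical)::

            embeddings = layer([input_ids, None, None, None])   # [batch_size, length, embedding_size]
            vocab_logits = layer(hidden_states, mode="linear")  # [batch_size, length, vocab_size]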
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [batch_size, length, embedding_size]
            Returns:
                float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]

        x = tf.reshape(inputs, [-1, self.embedding_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
        self.dense_prediction = tf.keras.layers.Dense(1, name="dense_prediction")
        self.config = config

    def call(self, discriminator_hidden_states, training=False):
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
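        # Squeeze the per-token predictions down to a single replaced-token-detection
        # (real vs. fake) logit per input token.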
        logits = tf.squeeze(self.dense_prediction(hidden_states))

        return logits


class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")

    def call(self, generator_hidden_states, training=False):
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = ACT2FN["gelu"](hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states


class TFElectraPreTrainedModel(TFBertPreTrainedModel):

    config_class = ElectraConfig
    base_model_prefix = "electra"

    def get_extended_attention_mask(self, attention_mask, input_shape):
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
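        #
        # Illustrative example: an attention_mask of [[1, 1, 0]] becomes
        # [[[[0.0, 0.0, -10000.0]]]] below, leaving real tokens untouched and giving the
        # padded position a large negative bias (~zero weight after the softmax).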

        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask

    def get_head_mask(self, head_mask):
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        return head_mask


@keras_serializable
class TFElectraMainLayer(TFElectraPreTrainedModel):

    config_class = ElectraConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.embeddings = TFElectraEmbeddings(config, name="embeddings")

        if config.embedding_size != config.hidden_size:
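            # ELECTRA can use an embedding size smaller than the hidden size (as in the
            # small models); project the embeddings up to the hidden size before the encoder.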
            self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.config = config

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        head_mask = self.get_head_mask(head_mask)

        hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)

        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states, training=training)

        hidden_states = self.encoder(
            [hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
            training=training,
        )

        return hidden_states


ELECTRA_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument (illustrated in the example below):

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
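
        For example (an illustrative sketch; the checkpoint name is just one of the pretrained ELECTRA models)::

            tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
            model = TFElectraModel.from_pretrained('google/electra-small-discriminator')

            inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
            outputs = model(inputs)               # dict/BatchEncoding as the first argument
            outputs = model(inputs["input_ids"])  # a single Tensor with input_ids only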

    Parameters:
        config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.ElectraTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
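            For example, a mask of ``[1, 1, 1, 0]`` attends to the first three tokens and ignores a trailing padding token.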

            `What are attention masks? <../glossary.html#attention-mask>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.

        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
    "hidden size and embedding size are different."
    ""
    "Both the generator and discriminator checkpoints may be loaded into this model.",
    ELECTRA_START_DOCSTRING,
)
class TFElectraModel(TFElectraPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.electra = TFElectraMainLayer(config, name="electra")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(self, inputs, **kwargs):
        r"""
    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        outputs = self.electra(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """Electra model with a binary classification head on top as used during pre-training for identifying generated
    tokens.

    Even though both the discriminator and generator may be loaded into this model, the discriminator is
    the only model of the two to have the correct classification head to be used for this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForPreTraining(TFElectraPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        r"""
    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import tensorflow as tf
        from transformers import ElectraTokenizer, TFElectraForPreTraining

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = TFElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        scores = outputs[0]
        """

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]
        logits = self.discriminator_predictions(discriminator_sequence_output)
        output = (logits,)
        output += discriminator_hidden_states[1:]

        return output  # (loss), scores, (hidden_states), (attentions)


class TFElectraMaskedLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states, training=False):
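        # Tied output projection: reuse the shared input embedding matrix (mode="linear")
        # to map hidden states to vocabulary logits, then add a learned output bias.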
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states


@add_start_docstrings(
    """Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is
    the only model of the two to have been trained for the masked language modeling task.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMaskedLM(TFElectraPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.vocab_size = config.vocab_size
        self.electra = TFElectraMainLayer(config, name="electra")
        self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

    def get_output_embeddings(self):
        return self.generator_lm_head

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-generator")
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        r"""
    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )
        generator_sequence_output = generator_hidden_states[0]
        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
        output = (prediction_scores,)
        output += generator_hidden_states[1:]

        return output  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[8] if len(inputs) > 8 else labels
            if len(inputs) > 8:
                inputs = inputs[:8]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]
        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
        logits = self.classifier(discriminator_sequence_output)

        outputs = (logits,) + discriminator_hidden_states[1:]

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.electra = TFElectraMainLayer(config, name="electra")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[8] if len(inputs) > 8 else start_positions
            end_positions = inputs[9] if len(inputs) > 9 else end_positions
            if len(inputs) > 8:
                inputs = inputs[:8]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.qa_outputs(discriminator_sequence_output)
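        # The dense layer produced two scores per token; split the last axis into separate
        # start and end logits, each squeezed to shape (batch_size, sequence_length).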
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, outputs[:2])
            outputs = (loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)