"test/srt/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "7d9679b74d08341bbe871d1470680fb712da3242"
modeling_tf_mobilebert.py 64 KB
Newer Older
Vasily Shamporov's avatar
Vasily Shamporov committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 MobileBERT model. """


Sylvain Gugger's avatar
Sylvain Gugger committed
19
20
from dataclasses import dataclass
from typing import Optional, Tuple
Vasily Shamporov's avatar
Vasily Shamporov committed
21
22
23
24

import tensorflow as tf

from . import MobileBertConfig
25
from .activations_tf import get_tf_activation
26
27
from .file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
Sylvain Gugger's avatar
Sylvain Gugger committed
28
    ModelOutput,
29
30
31
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
Sylvain Gugger's avatar
Sylvain Gugger committed
32
    replace_return_docstrings,
33
)
34
from .modeling_tf_bert import TFBertIntermediate
Sylvain Gugger's avatar
Sylvain Gugger committed
35
36
37
38
39
40
41
42
43
44
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFNextSentencePredictorOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
Vasily Shamporov's avatar
Vasily Shamporov committed
45
from .modeling_tf_utils import (
46
    TFMaskedLanguageModelingLoss,
Vasily Shamporov's avatar
Vasily Shamporov committed
47
48
49
50
51
52
53
54
55
56
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
Lysandre Debut's avatar
Lysandre Debut committed
57
from .utils import logging
Vasily Shamporov's avatar
Vasily Shamporov committed
58
59


Lysandre Debut's avatar
Lysandre Debut committed
60
logger = logging.get_logger(__name__)

# Class names referenced by the documentation helpers; presumably consumed by the
# docstring decorators imported above (e.g. ``add_code_sample_docstrings``) — verify.
_CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
Vasily Shamporov's avatar
Vasily Shamporov committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

# Identifiers of the pretrained MobileBERT checkpoints known to this module.
# See all MobileBERT models at https://huggingface.co/models?filter=mobilebert
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["mobilebert-uncased"]


class TFLayerNorm(tf.keras.layers.LayerNormalization):
    """Keras ``LayerNormalization`` whose constructor takes (and discards) a leading feature size.

    The extra positional ``feat_size`` argument gives this class the same constructor shape as
    ``TFNoNorm`` so that either can be instantiated through ``NORM2FN`` with identical arguments.
    All remaining positional and keyword arguments are forwarded to ``LayerNormalization``.
    """

    def __init__(self, feat_size, *args, **kwargs):
        # ``feat_size`` is deliberately ignored; it exists only for signature compatibility.
        del feat_size
        super().__init__(*args, **kwargs)


class TFNoNorm(tf.keras.layers.Layer):
    """Element-wise affine transform ``inputs * weight + bias`` used in place of layer normalization.

    Args:
        feat_size: length of the learned ``weight`` and ``bias`` vectors, which broadcast against
            the last dimension of the inputs.
        epsilon: accepted only for signature compatibility with ``TFLayerNorm``; unused.
    """

    def __init__(self, feat_size, epsilon=None, **kwargs):
        super().__init__(**kwargs)
        self.feat_size = feat_size

    def build(self, input_shape):
        # Weights are created in the same order as before (bias, then weight).
        vector_shape = [self.feat_size]
        self.bias = self.add_weight("bias", shape=vector_shape, initializer="zeros")
        self.weight = self.add_weight("weight", shape=vector_shape, initializer="ones")

    def call(self, inputs: tf.Tensor):
        scaled = inputs * self.weight
        return scaled + self.bias


# Maps ``config.normalization_type`` to the normalization layer class to instantiate.
NORM2FN = dict(layer_norm=TFLayerNorm, no_norm=TFNoNorm)


class TFMobileBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.

    Word embeddings live in a ``[vocab_size, embedding_size]`` table. When ``trigram_input`` is set,
    or when ``embedding_size`` differs from ``hidden_size``, they are projected to ``hidden_size`` by
    ``embedding_transformation``. The projected word embeddings are then summed with position and
    token-type embeddings, normalized, and passed through dropout.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.trigram_input = config.trigram_input
        self.embedding_size = config.embedding_size
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        self.embedding_transformation = tf.keras.layers.Dense(config.hidden_size, name="embedding_transformation")

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        mode="embedding",
        training=False,
    ):
        """Get token embeddings of inputs, or project a tensor onto the vocabulary.

        Args:
            input_ids: in ``"embedding"`` mode, integer token ids of shape [batch_size, length]
                (may be None when ``inputs_embeds`` is given). In ``"linear"`` mode, the float tensor
                to project, of shape [batch_size, length, embedding_size].
            position_ids: optional position ids; defaults to ``0 .. length - 1`` broadcast over the batch.
            token_type_ids: optional token type ids; defaults to all zeros.
            inputs_embeds: optional precomputed word embeddings used instead of looking up ``input_ids``.
            mode: string, a valid value is one of "embedding" and "linear".
            training: whether dropout is active (``"embedding"`` mode only).
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, hidden_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
        """Applies embedding based on inputs tensor."""
        assert not (input_ids is None and inputs_embeds is None), "Provide either input_ids or inputs_embeds."

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            # Shape [1, seq_length]; broadcasts over the batch when the embeddings are summed.
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        if self.trigram_input:
            # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited
            # Devices (https://arxiv.org/abs/2004.02984)
            #
            # The embedding table in BERT models accounts for a substantial proportion of model size. To compress
            # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT.
            # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512
            # dimensional output.
            #
            # Here the kernel-3 convolution is realized by concatenating, along the feature axis, the next
            # token's embedding, the current one and the previous one (zero-padded at the sequence edges),
            # followed by the dense ``embedding_transformation`` below.
            inputs_embeds = tf.concat(
                [
                    tf.pad(inputs_embeds[:, 1:], ((0, 0), (0, 1), (0, 0))),
                    inputs_embeds,
                    tf.pad(inputs_embeds[:, :-1], ((0, 0), (1, 0), (0, 0))),
                ],
                axis=2,
            )

        if self.trigram_input or self.embedding_size != self.hidden_size:
            inputs_embeds = self.embedding_transformation(inputs_embeds)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer tied to the word embeddings.

        Args:
            inputs: A float32 tensor with shape [batch_size, length, embedding_size]. The last dimension
                must equal ``embedding_size`` because it is multiplied by the transposed
                ``[vocab_size, embedding_size]`` word-embedding matrix.
        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size, length = shape_list(inputs)[:2]

        # Flatten by ``embedding_size`` (the inner dimension of ``self.word_embeddings``). Flattening by
        # ``hidden_size`` made the matmul fail whenever hidden_size != embedding_size, as in MobileBERT.
        x = tf.reshape(inputs, [-1, self.embedding_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFMobileBertSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention used inside each MobileBERT layer.

    Query, key and value are each projected to ``num_attention_heads * attention_head_size`` features.
    The per-head size is derived from ``config.true_hidden_size``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.output_attentions = config.output_attentions
        # Heads partition ``true_hidden_size`` (the possibly bottlenecked width), not ``hidden_size``.
        self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        """Split the last axis into heads: [batch, seq, all_head] -> [batch, heads, seq, head_size]."""
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(
        self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False
    ):
        """Attend from ``query_tensor`` over ``key_tensor`` / ``value_tensor``.

        Args:
            query_tensor, key_tensor, value_tensor: inputs to the query/key/value projections; assumed to be
                [batch_size, seq_len, features] — TODO confirm against callers.
            attention_mask: additive mask added to the raw attention scores, or None for no masking.
            head_mask: multiplicative mask applied to the attention probabilities, or None.
            output_attentions: if truthy, the attention probabilities are returned as a second element.
            training: whether dropout on the attention probabilities is active.
        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``; ``context_layer`` has shape
            [batch_size, seq_len_q, all_head_size].
        """
        # Take the batch size from the query rather than the mask, so that ``attention_mask=None``
        # (which is explicitly handled below) does not crash here.
        batch_size = shape_list(query_tensor)[0]
        mixed_query_layer = self.query(query_tensor)
        mixed_key_layer = self.key(key_tensor)
        mixed_value_layer = self.value(value_tensor)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        # Scale by sqrt(head_size) in the scores' own dtype so reduced-precision scores are not
        # divided by a float32 tensor (TensorFlow does not mix dtypes implicitly).
        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask (presumably precomputed once for all layers by the caller).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)

        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class TFMobileBertSelfOutput(tf.keras.layers.Layer):
    """Project the attention context to ``true_hidden_size``, add a residual, then normalize.

    Dropout on the projection is only created and applied when ``config.use_bottleneck`` is false.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.use_bottleneck = config.use_bottleneck
        self.dense = tf.keras.layers.Dense(
            config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        if not self.use_bottleneck:
            self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, residual_tensor, training=False):
        projected = self.dense(hidden_states)
        if self.use_bottleneck:
            # No dropout layer exists in bottleneck mode.
            return self.LayerNorm(projected + residual_tensor)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + residual_tensor)


class TFMobileBertAttention(tf.keras.layers.Layer):
    """Self-attention (``self``) followed by its projection/residual/normalization block (``output``)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self = TFMobileBertSelfAttention(config, name="self")
        self.mobilebert_output = TFMobileBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        query_tensor,
        key_tensor,
        value_tensor,
        layer_input,
        attention_mask,
        head_mask,
        output_attentions,
        training=False,
    ):
        """Run attention and combine its context with ``layer_input`` as the residual.

        Returns:
            A tuple whose first element is the attention block output, followed by whatever extra
            elements the self-attention returned (the attention probabilities when requested).
        """
        attention_results = self.self(
            query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training
        )
        context = attention_results[0]
        extras = attention_results[1:]

        combined = self.mobilebert_output(context, layer_input, training=training)
        return (combined,) + extras


class TFMobileBertIntermediate(TFBertIntermediate):
    """``TFBertIntermediate`` with its ``dense`` sublayer replaced by a ``intermediate_size`` projection.

    The replacement ``Dense`` layer is created after the parent constructor runs and passes no
    ``kernel_initializer``, so the Keras default initializer applies to it.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Overrides whatever ``dense`` layer the parent constructor set up.
        self.dense = tf.keras.layers.Dense(units=config.intermediate_size, name="dense")


class TFOutputBottleneck(tf.keras.layers.Layer):
    """Project to ``hidden_size``, apply dropout, add a residual, then normalize."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, residual_tensor, training=False):
        expanded = self.dropout(self.dense(hidden_states), training=training)
        return self.LayerNorm(expanded + residual_tensor)


class TFMobileBertOutput(tf.keras.layers.Layer):
    """Final projection of a MobileBERT layer back to ``true_hidden_size``.

    Without bottlenecks: dense -> dropout -> add ``residual_tensor_1`` -> normalize.
    With bottlenecks: dense -> add ``residual_tensor_1`` -> normalize -> ``TFOutputBottleneck``. The
    bottleneck projects to ``hidden_size`` and adds ``residual_tensor_2``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.use_bottleneck = config.use_bottleneck
        self.dense = tf.keras.layers.Dense(
            config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )
        if not self.use_bottleneck:
            self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        else:
            self.bottleneck = TFOutputBottleneck(config, name="bottleneck")

    def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
        """Apply the output projection.

        Args:
            hidden_states: intermediate (feed-forward) activations to project.
            residual_tensor_1: residual added before the first normalization.
            residual_tensor_2: residual used by the output bottleneck; ignored when bottlenecks are off.
            training: whether dropout is active, in both this layer and the bottleneck.
        """
        hidden_states = self.dense(hidden_states)
        if not self.use_bottleneck:
            hidden_states = self.dropout(hidden_states, training=training)
            hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
        else:
            hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
            # Forward ``training`` explicitly (as the dropout call in the other branch does) so the
            # bottleneck's dropout follows this call's mode instead of relying on its ``False`` default.
            hidden_states = self.bottleneck(hidden_states, residual_tensor_2, training=training)
        return hidden_states


class TFBottleneckLayer(tf.keras.layers.Layer):
    """Linear projection to ``intra_bottleneck_size`` followed by normalization (no activation)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.intra_bottleneck_size, name="dense")
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.intra_bottleneck_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )

    def call(self, inputs):
        return self.LayerNorm(self.dense(inputs))


class TFBottleneck(tf.keras.layers.Layer):
    """Produce the (query, key, value, layer_input) tensors fed to a MobileBERT attention block.

    Bottlenecks are learned linear projections (``TFBottleneckLayer``) to ``intra_bottleneck_size``,
    which reduce the width the attention block operates on.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
        self.use_bottleneck_attention = config.use_bottleneck_attention
        self.bottleneck_input = TFBottleneckLayer(config, name="input")
        if self.key_query_shared_bottleneck:
            self.attention = TFBottleneckLayer(config, name="attention")

    def call(self, hidden_states):
        """Return a 4-tuple ``(query, key, value, layer_input)``.

        ``layer_input`` is always ``hidden_states`` passed through the input bottleneck. It is later
        used as the residual in the attention self-output. The other three depend on the config:

        * ``use_bottleneck_attention``: query, key and value are all that same bottlenecked tensor.
        * otherwise, with ``key_query_shared_bottleneck``: query and key come from a second, shared
          bottleneck (``self.attention``); value is the raw ``hidden_states``.
        * otherwise: query, key and value are the raw ``hidden_states``.
        """
        # Computed first in every case, matching the original evaluation order.
        layer_input = self.bottleneck_input(hidden_states)

        if self.use_bottleneck_attention:
            return (layer_input,) * 4

        if not self.key_query_shared_bottleneck:
            return (hidden_states, hidden_states, hidden_states, layer_input)

        shared_query_key = self.attention(hidden_states)
        return (shared_query_key, shared_query_key, hidden_states, layer_input)


class TFFFNOutput(tf.keras.layers.Layer):
    """Output half of an extra feed-forward block: project to ``true_hidden_size``, add residual, normalize."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.true_hidden_size, name="dense")
        self.LayerNorm = NORM2FN[config.normalization_type](
            config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
        )

    def call(self, hidden_states, residual_tensor):
        projected = self.dense(hidden_states)
        return self.LayerNorm(projected + residual_tensor)


class TFFFNLayer(tf.keras.layers.Layer):
    """One extra feed-forward block: intermediate expansion, then ``TFFFNOutput`` with the input as residual."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
        self.mobilebert_output = TFFFNOutput(config, name="output")

    def call(self, hidden_states):
        return self.mobilebert_output(self.intermediate(hidden_states), hidden_states)


class TFMobileBertLayer(tf.keras.layers.Layer):
    """A single MobileBERT transformer layer.

    Pipeline:
    1. Optional input bottleneck.
    2. Attention.
    3. ``num_feedforward_networks - 1`` extra feed-forward blocks.
    4. Intermediate expansion.
    5. Output projection.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.use_bottleneck = config.use_bottleneck
        self.num_feedforward_networks = config.num_feedforward_networks
        self.attention = TFMobileBertAttention(config, name="attention")
        self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
        self.mobilebert_output = TFMobileBertOutput(config, name="output")

        if self.use_bottleneck:
            self.bottleneck = TFBottleneck(config, name="bottleneck")
        if config.num_feedforward_networks > 1:
            self.ffn = [
                TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1)
            ]

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        """Run one layer.

        Returns:
            A tuple starting with the layer output, then the attention probabilities when
            ``output_attentions`` is truthy. It is followed by extra tensors: a ``tf.constant(0)``
            placeholder, the query/key/value/layer inputs, the attention output, the intermediate
            output, and the outputs collected before/after each extra FFN.
            NOTE(review): the extras look diagnostic; the encoder in this file reads only indices 0 and 1.
        """
        if self.use_bottleneck:
            query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
        else:
            query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4

        attention_outputs = self.attention(
            query_tensor,
            key_tensor,
            value_tensor,
            layer_input,
            attention_mask,
            head_mask,
            output_attentions,
            training=training,
        )

        attention_output = attention_outputs[0]
        s = (attention_output,)

        # Same condition as in ``__init__``: ``self.ffn`` only exists when there is more than one FFN,
        # so a value of 0 (or less) must not try to iterate it.
        if self.num_feedforward_networks > 1:
            for ffn_module in self.ffn:
                attention_output = ffn_module(attention_output)
                s += (attention_output,)

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training)

        outputs = (
            (layer_output,)
            + attention_outputs[1:]  # add attentions if we output them
            + (
                tf.constant(0),
                query_tensor,
                key_tensor,
                value_tensor,
                layer_input,
                attention_output,
                intermediate_output,
            )
            + s
        )

        return outputs


class TFMobileBertEncoder(tf.keras.layers.Layer):
    """Stack of ``config.num_hidden_layers`` ``TFMobileBertLayer`` modules applied in sequence."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):
        """Run every layer, optionally collecting per-layer hidden states and attention probabilities.

        ``head_mask`` is indexed per layer (``head_mask[i]`` for layer ``i``). When hidden states are
        collected, they include the input to every layer plus the final output.
        """
        hidden_history = [] if output_hidden_states else None
        attention_history = [] if output_attentions else None

        for layer_index, layer_module in enumerate(self.layer):
            if output_hidden_states:
                hidden_history.append(hidden_states)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[layer_index], output_attentions, training=training
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                attention_history.append(layer_outputs[1])

        # Add last layer
        if output_hidden_states:
            hidden_history.append(hidden_states)

        all_hidden_states = None if hidden_history is None else tuple(hidden_history)
        all_attentions = None if attention_history is None else tuple(attention_history)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )
        return tuple(item for item in (hidden_states, all_hidden_states, all_attentions) if item is not None)
Vasily Shamporov's avatar
Vasily Shamporov committed
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614


class TFMobileBertPooler(tf.keras.layers.Layer):
    """Pool a sequence by taking its first token's hidden state.

    When ``config.classifier_activation`` is truthy, the first-token vector is additionally passed
    through a tanh ``Dense`` layer; otherwise it is returned unchanged.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.do_activate = config.classifier_activation
        if self.do_activate:
            self.dense = tf.keras.layers.Dense(
                config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                activation="tanh",
                name="dense",
            )

    def call(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        if self.do_activate:
            return self.dense(first_token_tensor)
        return first_token_tensor


class TFMobileBertPredictionHeadTransform(tf.keras.layers.Layer):
    """Dense -> activation -> LayerNorm transform applied before the LM output projection.

    ``config.hidden_act`` may be an activation name (resolved with ``get_tf_activation``) or a
    callable used as-is. The normalization here is always ``"layer_norm"``, regardless of
    ``config.normalization_type``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        activation = config.hidden_act
        self.transform_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation
        self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm")

    def call(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)


class TFMobileBertLMPredictionHead(tf.keras.layers.Layer):
    """Language-modeling output head producing vocabulary logits.

    The output projection is assembled from two weights: ``decoder`` of shape
    ``(vocab_size, embedding_size)`` (used transposed) and ``dense`` of shape
    ``(hidden_size - embedding_size, vocab_size)``. Stacked along axis 0 they form a
    ``(hidden_size, vocab_size)`` matrix, to which a per-token ``bias`` is added.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.transform = TFMobileBertPredictionHeadTransform(config, name="transform")
        self.vocab_size = config.vocab_size
        self.config = config

    def build(self, input_shape):
        cfg = self.config
        # Weight creation order (bias, dense, decoder) is kept stable for checkpoint loading.
        self.bias = self.add_weight(name="bias", shape=(self.vocab_size,), initializer="zeros", trainable=True)
        self.dense = self.add_weight(
            name="dense/weight",
            shape=(cfg.hidden_size - cfg.embedding_size, self.vocab_size),
            initializer="zeros",
            trainable=True,
        )
        self.decoder = self.add_weight(
            name="decoder/weight",
            shape=(cfg.vocab_size, cfg.embedding_size),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, hidden_states):
        transformed = self.transform(hidden_states)
        output_matrix = tf.concat([tf.transpose(self.decoder), self.dense], axis=0)
        logits = tf.matmul(transformed, output_matrix)
        return logits + self.bias


class TFMobileBertMLMHead(tf.keras.layers.Layer):
    """Masked-language-modeling head: delegates to a ``TFMobileBertLMPredictionHead`` named ``predictions``."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFMobileBertLMPredictionHead(config, name="predictions")

    def call(self, sequence_output):
        return self.predictions(sequence_output)


@keras_serializable
class TFMobileBertMainLayer(tf.keras.layers.Layer):
    """MobileBERT embeddings + encoder + pooler, used as the ``mobilebert`` sub-layer of the TF MobileBERT models."""

    config_class = MobileBertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict

        self.embeddings = TFMobileBertEmbeddings(config, name="embeddings")
        self.encoder = TFMobileBertEncoder(config, name="encoder")
        self.pooler = TFMobileBertPooler(config, name="pooler")

    def get_input_embeddings(self):
        return self.embeddings

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Run the embeddings, the encoder and the pooler.

        ``inputs`` may be a tensor of input ids, a tuple/list holding up to nine inputs in signature order,
        or a dict/``BatchEncoding`` keyed by argument name. Values found inside ``inputs`` take precedence
        over the corresponding keyword arguments; ``None`` flags fall back to the config.

        Returns:
            A ``TFBaseModelOutputWithPooling`` when ``return_dict`` resolves to True, otherwise the tuple
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are provided.
            NotImplementedError: if ``head_mask`` is provided (head masking is not supported here).
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # Turn the [batch_size, to_seq_length] mask into a broadcastable 4D mask of
        # shape [batch_size, 1, 1, to_seq_length], so it can be added to attention scores of shape
        # [batch_size, num_heads, from_seq_length, to_seq_length]. This is simpler than the triangular
        # masking of causal attention used in OpenAI GPT; only the broadcast dimensions are needed.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Head masking is not supported by this implementation; each encoder layer receives None.
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
Vasily Shamporov's avatar
Vasily Shamporov committed
803
804
805


class TFMobileBertPreTrainedModel(TFPreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = MobileBertConfig
    # Matches the ``self.mobilebert`` attribute under which subclasses store their main layer.
    base_model_prefix = "mobilebert"


Sylvain Gugger's avatar
Sylvain Gugger committed
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
@dataclass
class TFMobileBertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.TFMobileBertForPreTraining`.

    Args:
        loss (:obj:`tf.Tensor`, `optional`):
            Pre-training loss. Optional; left as :obj:`None` when no loss is computed.
        prediction_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[tf.Tensor] = None
    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


Vasily Shamporov's avatar
Vasily Shamporov committed
845
MOBILEBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)

    This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass.
    Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general
    usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Parameters:
        config (:class:`~transformers.MobileBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""

MOBILEBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.MobileBertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.__call__` and
            :func:`transformers.PreTrainedTokenizer.encode` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

            Not yet supported by the TF MobileBERT models: passing a value raises :obj:`NotImplementedError`.
        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
        training (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.",
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertModel(TFMobileBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        # Thin pass-through: all input parsing and output formatting happens in the main layer.
        outputs = self.mobilebert(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """MobileBert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
        # Pass the config, matching TFMobileBertOnlyNSPHead's ``(config, **kwargs)`` signature
        # (the head's output width is fixed at 2 regardless).
        self.seq_relationship = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")

    def get_output_embeddings(self):
        return self.mobilebert.embeddings

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import MobileBertTokenizer, TFMobileBertForPreTraining

            >>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
            >>> model = TFMobileBertForPreTraining.from_pretrained('google/mobilebert-uncased')
            >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
            >>> outputs = model(input_ids)
            >>> prediction_scores, seq_relationship_scores = outputs[:2]

        """
        # Resolve ``return_dict`` exactly as the main layer does (inputs tuple/dict override the kwarg,
        # then fall back to the config). Otherwise the main layer could return a tuple while this
        # method tries to read ``outputs.hidden_states`` from it.
        return_dict = kwargs.get("return_dict")
        if isinstance(inputs, (tuple, list)):
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
        elif isinstance(inputs, (dict, BatchEncoding)):
            return_dict = inputs.get("return_dict", return_dict)
        return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict

        outputs = self.mobilebert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return TFMobileBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Vasily Shamporov's avatar
Vasily Shamporov committed
1007
1008
1009


@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.mlm = TFMobileBertMLMHead(config, name="mlm___cls")

    def get_output_embeddings(self):
        return self.mobilebert.embeddings

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        """
        if isinstance(inputs, (tuple, list)):
            # Position 9 holds the labels; position 8 (return_dict) wins over the kwarg in the main layer.
            labels = inputs[9] if len(inputs) > 9 else labels
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # Copy before popping so the caller's dict/BatchEncoding keeps its "labels" entry.
            inputs = dict(inputs)
            labels = inputs.pop("labels", labels)
            return_dict = inputs.get("return_dict", return_dict)
        # Resolve after parsing so this method and the main layer agree on the output format.
        return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict

        outputs = self.mobilebert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=training)

        loss = None if labels is None else self.compute_loss(labels, prediction_scores)

        if not return_dict:
            # outputs[1] is the pooled output, which the MLM head does not use.
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Vasily Shamporov's avatar
Vasily Shamporov committed
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103


class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
    """Next-sentence-prediction head: a 2-unit dense projection of the pooled output.

    ``config`` is accepted for signature compatibility with the other heads but is not read.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.seq_relationship = tf.keras.layers.Dense(2, name="seq_relationship")

    def call(self, pooled_output):
        return self.seq_relationship(pooled_output)


@add_start_docstrings(
    """MobileBert Model with a `next sentence prediction (classification)` head on top. """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.cls = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import MobileBertTokenizer, TFMobileBertForNextSentencePrediction

            >>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
            >>> model = TFMobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased')

            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
            >>> encoding = tokenizer(prompt, next_sentence, return_tensors='tf')

            >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
        """
        # Resolve ``return_dict`` exactly as the main layer does (inputs tuple/dict override the kwarg,
        # then fall back to the config). Otherwise the main layer could return a tuple while this
        # method tries to read ``outputs.hidden_states`` from it.
        return_dict = kwargs.get("return_dict")
        if isinstance(inputs, (tuple, list)):
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
        elif isinstance(inputs, (dict, BatchEncoding)):
            return_dict = inputs.get("return_dict", return_dict)
        return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict

        outputs = self.mobilebert(inputs, **kwargs)

        pooled_output = outputs[1]
        seq_relationship_score = self.cls(pooled_output)

        if not return_dict:
            return (seq_relationship_score,) + outputs[2:]

        return TFNextSentencePredictorOutput(
            logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Vasily Shamporov's avatar
Vasily Shamporov committed
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158


@add_start_docstrings(
    """MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
    # Sequence-level head: pooled MobileBERT output -> dropout -> dense layer with one
    # output unit per label. The loss comes from TFSequenceClassificationLoss.compute_loss.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        # Sub-layers are created in the same order as before; names are explicit.
        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        head_initializer = get_initializer(config.initializer_range)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=head_initializer, name="classifier"
        )

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.mobilebert.return_dict

        # ``labels`` may also be packed into ``inputs``: either as the 10th positional entry
        # of a tuple/list, or under the "labels" key of a dict/BatchEncoding. It is removed
        # before forwarding, so the main layer only receives the arguments it understands.
        if isinstance(inputs, (tuple, list)):
            if len(inputs) > 9:
                labels = inputs[9]
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` mutates the caller's mapping.
            labels = inputs.pop("labels", labels)

        backbone_outputs = self.mobilebert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Index 1 of the backbone outputs is used as the pooled representation.
        pooled = self.dropout(backbone_outputs[1], training=training)
        logits = self.classifier(pooled)

        loss = self.compute_loss(labels, logits) if labels is not None else None

        if return_dict:
            return TFSequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        # Tuple form: (loss?, logits, *extra backbone outputs); loss is included only when computed.
        tuple_output = (logits,) + backbone_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
Vasily Shamporov's avatar
Vasily Shamporov committed
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241


@add_start_docstrings(
    """MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
    # Extractive QA head: a dense layer applied to every token's hidden state. Its output is
    # split along the last axis into span-start and span-end logits. The split-by-2 followed by
    # a squeeze of size-1 axes in ``call`` requires ``config.num_labels == 2``.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict

        # Span labels may be packed into ``inputs``. For tuples/lists they sit at positional
        # slots 9 and 10. For dicts/BatchEncodings they sit under their own keys. They are
        # removed before forwarding, so the main layer only sees its own arguments.
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` mutates the caller's mapping.
            start_positions = inputs.pop("start_positions", start_positions)
            # Each label falls back to its OWN keyword argument. Falling back to
            # ``start_positions`` here would silently reuse the start labels for the end logits.
            end_positions = inputs.pop("end_positions", end_positions)

        outputs = self.mobilebert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Split the per-token dense output into two halves, then drop the trailing size-1 axis.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        # Loss is computed only when BOTH span boundaries are supplied.
        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions, "end_position": end_positions}
            loss = self.compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Vasily Shamporov's avatar
Vasily Shamporov committed
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339


@add_start_docstrings(
    """MobileBert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
    # Multiple-choice head. The (batch, num_choices, seq_len) inputs are flattened to
    # (batch * num_choices, seq_len), and each choice is scored with a single-unit dense layer
    # on the pooled output. The scores are then reshaped back to (batch, num_choices).

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        head_initializer = get_initializer(config.initializer_range)
        self.classifier = tf.keras.layers.Dense(1, kernel_initializer=head_initializer, name="classifier")

    @property
    def dummy_inputs(self):
        """Dummy inputs used to build the network.

        Returns:
            dict mapping ``"input_ids"`` to a constant tensor made from ``MULTIPLE_CHOICE_DUMMY_INPUTS``.
        """
        dummy_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
        return {"input_ids": dummy_ids}

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
            of the input tensors. (See :obj:`input_ids` above)
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            # Positional slots 1..9 override the keyword values when present; missing trailing
            # slots keep whatever was passed by keyword.
            positional = tuple(inputs[1:10])
            keyword_values = (
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                inputs_embeds,
                output_attentions,
                output_hidden_states,
                return_dict,
                labels,
            )
            (
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                inputs_embeds,
                output_attentions,
                output_hidden_states,
                return_dict,
                labels,
            ) = positional + keyword_values[len(positional):]
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            lookup = inputs.get
            input_ids = lookup("input_ids")
            attention_mask = lookup("attention_mask", attention_mask)
            token_type_ids = lookup("token_type_ids", token_type_ids)
            position_ids = lookup("position_ids", position_ids)
            head_mask = lookup("head_mask", head_mask)
            inputs_embeds = lookup("inputs_embeds", inputs_embeds)
            output_attentions = lookup("output_attentions", output_attentions)
            output_hidden_states = lookup("output_hidden_states", output_hidden_states)
            return_dict = lookup("return_dict", return_dict)
            labels = lookup("labels", labels)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        if return_dict is None:
            return_dict = self.mobilebert.return_dict

        # Axis 1 is the choice axis and axis 2 the token axis of whichever input was given.
        reference_shape = shape_list(input_ids if input_ids is not None else inputs_embeds)
        num_choices, seq_length = reference_shape[1], reference_shape[2]

        def flatten_choices(tensor):
            # Merge the batch and choice axes: (batch, num_choices, seq) -> (batch * num_choices, seq).
            return None if tensor is None else tf.reshape(tensor, (-1, seq_length))

        flat_input_ids = flatten_choices(input_ids)
        flat_attention_mask = flatten_choices(attention_mask)
        flat_token_type_ids = flatten_choices(token_type_ids)
        flat_position_ids = flatten_choices(position_ids)
        if inputs_embeds is None:
            flat_inputs_embeds = None
        else:
            embedding_width = shape_list(inputs_embeds)[3]
            flat_inputs_embeds = tf.reshape(inputs_embeds, (-1, seq_length, embedding_width))

        # head_mask is forwarded unchanged: it is not per-choice.
        outputs = self.mobilebert(
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled = self.dropout(outputs[1], training=training)
        choice_scores = tf.reshape(self.classifier(pooled), (-1, num_choices))

        loss = self.compute_loss(labels, choice_scores) if labels is not None else None

        if return_dict:
            return TFMultipleChoiceModelOutput(
                loss=loss,
                logits=choice_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_output = (choice_scores,) + outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
Vasily Shamporov's avatar
Vasily Shamporov committed
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464


@add_start_docstrings(
    """MobileBert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    MOBILEBERT_START_DOCSTRING,
)
class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
    # Token-level head: per-token hidden states -> dropout -> dense layer with one output unit
    # per label. The loss comes from TFTokenClassificationLoss.compute_loss.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        head_initializer = get_initializer(config.initializer_range)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=head_initializer, name="classifier"
        )

    @add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/mobilebert-uncased",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        if return_dict is None:
            return_dict = self.mobilebert.return_dict

        # ``labels`` may be packed into ``inputs``: as positional slot 9 of a tuple/list, or
        # under the "labels" key of a dict/BatchEncoding. It is removed before forwarding.
        if isinstance(inputs, (tuple, list)):
            if len(inputs) > 9:
                labels = inputs[9]
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` mutates the caller's mapping.
            labels = inputs.pop("labels", labels)

        backbone_outputs = self.mobilebert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Index 0 of the backbone outputs is used as the per-token hidden states.
        token_states = self.dropout(backbone_outputs[0], training=training)
        logits = self.classifier(token_states)

        loss = self.compute_loss(labels, logits) if labels is not None else None

        if return_dict:
            return TFTokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        tuple_output = (logits,) + backbone_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output