modeling_tf_bert.py 51.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

from .configuration_bert import BertConfig
from .modeling_tf_utils import TFPreTrainedModel
from .file_utils import add_start_docstrings
33
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
thomwolf's avatar
thomwolf committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

logger = logging.getLogger(__name__)


TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
}


def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch BERT checkpoint into a TF 2.0 BERT model.

    The model is first called once on a small dummy batch so that all of its
    variables get built. The PyTorch weights are then copied into those variables by
    ``load_pytorch_checkpoint_in_tf2_model``.

    Args:
        tf_model: the TF 2.0 BERT model instance to fill.
        pytorch_checkpoint_path: path to the PyTorch checkpoint file.

    Returns:
        The result of ``load_pytorch_checkpoint_in_tf2_model``.
    """
    # Dummy token ids: their only purpose is to build the network's variables.
    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    tf_inputs = tf.constant(inputs_list)
    tf_model(tf_inputs, training=False)  # build the network; outputs are not needed
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)


def gelu(x):
    """ Gaussian Error Linear Unit (exact, erf-based form).

    Original implementation of the gelu activation function in the Google BERT repo
    when initially created:
        gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
    For information, OpenAI GPT's gelu (see ``gelu_new``) uses a tanh approximation
    that gives slightly different results:
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    Also see https://arxiv.org/abs/1606.08415

    Args:
        x: floating-point tensor.
    Returns:
        ``x`` with the GELU activation applied.
    """
    # Divide by a Python float instead of a float32 ``tf.math.sqrt(2.0)`` tensor.
    # TF converts the Python scalar to x's own dtype, so float16/float64 inputs
    # no longer fail with a dtype mismatch.
    cdf = 0.5 * (1.0 + tf.math.erf(x / math.sqrt(2.0)))
    return x * cdf

def gelu_new(x):
    """Gaussian Error Linear Unit, tanh approximation (as used by OpenAI GPT).

    A smoother version of the ReLU:
        gelu_new(x) = x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    Original paper: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * cdf

def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = tf.sigmoid(x)
    return x * gate


# Maps activation names (as found in ``config.hidden_act``) to activations.
# Custom functions are wrapped in Keras ``Activation`` layers; relu uses the
# built-in Keras activation function directly.
ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
}


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Word + position + token-type embeddings, followed by LayerNorm and dropout.

    The word-embedding matrix is created directly in ``build`` (not as an
    ``Embedding`` sub-layer). It is reused in ``mode="linear"`` to project hidden
    states back onto the vocabulary.
    """
    def __init__(self, config, **kwargs):
        super(TFBertEmbeddings, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings, config.hidden_size, name='position_embeddings')
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size, config.hidden_size, name='token_type_embeddings')

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name='LayerNorm')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Create the shared word-embedding matrix of shape ``[vocab_size, hidden_size]``."""
        # The random normal initializer was chosen arbitrarily, and works well.
        initializer = tf.random_normal_initializer(mean=0., stddev=self.hidden_size ** -0.5)
        with tf.name_scope("word_embeddings"):
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=initializer)
        super(TFBertEmbeddings, self).build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Embed token ids, or project hidden states onto the vocabulary.

        Args:
            inputs: for ``mode="embedding"``, a list ``[input_ids, position_ids, token_type_ids]``
                of integer tensors of shape ``[batch_size, length]`` (the last two may be None);
                for ``mode="linear"``, a tensor of shape ``[batch_size, length, hidden_size]``.
            mode: either ``"embedding"`` or ``"linear"``.
            training: whether dropout is active (only used in ``"embedding"`` mode).
        Returns:
            ``"embedding"``: tensor of shape ``[batch_size, length, hidden_size]``;
            ``"linear"``: logits tensor of shape ``[batch_size, length, vocab_size]``.
        Raises:
            ValueError: if ``mode`` is neither of the two accepted values.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        if mode == "linear":
            return self._linear(inputs)
        raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """Look up and combine the embeddings for ``[input_ids, position_ids, token_type_ids]``."""
        input_ids, position_ids, token_type_ids = inputs
        ids_shape = tf.shape(input_ids)

        if position_ids is None:
            # Positions 0..length-1 with a leading axis so they broadcast over the batch.
            position_ids = tf.range(ids_shape[1], dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(ids_shape, 0)

        word_vectors = tf.gather(self.word_embeddings, input_ids)
        position_vectors = self.position_embeddings(position_ids)
        segment_vectors = self.token_type_embeddings(token_type_ids)

        summed = word_vectors + position_vectors + segment_vectors
        normalized = self.LayerNorm(summed)
        return self.dropout(normalized, training=training)

    def _linear(self, inputs):
        """Multiply hidden states by the transposed word-embedding matrix.

        Args:
            inputs: tensor of shape [batch_size, length, hidden_size]
        Returns:
            tensor of shape [batch_size, length, vocab_size].
        """
        dims = tf.shape(inputs)
        flat = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(flat, self.word_embeddings, transpose_b=True)
        return tf.reshape(logits, [dims[0], dims[1], self.vocab_size])

class TFBertSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    Projects the hidden states to query/key/value with three Dense layers and splits
    them into ``num_attention_heads`` heads. Returns the attention-weighted values,
    plus the attention probabilities when ``config.output_attentions`` is set.
    """
    def __init__(self, config, **kwargs):
        super(TFBertSelfAttention, self).__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        # Exact integer division: divisibility was checked above.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = tf.keras.layers.Dense(self.all_head_size, name='query')
        self.key = tf.keras.layers.Dense(self.all_head_size, name='key')
        self.value = tf.keras.layers.Dense(self.all_head_size, name='value')

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        """Split the last axis into heads: (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)."""
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        """Apply self-attention.

        Args:
            inputs: ``[hidden_states, attention_mask, head_mask]``; both masks may be None.
                ``attention_mask`` is *added* to the raw scores (0 keeps a position, a large
                negative value masks it). ``head_mask`` is *multiplied* into the
                attention probabilities.
            training: whether dropout is active.
        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)`` when
            ``output_attentions`` is set.
        """
        hidden_states, attention_mask, head_mask = inputs

        batch_size = tf.shape(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)  # (batch size, num_heads, seq_len_q, seq_len_k)
        # Scale by sqrt(head_size). Cast to the scores' own dtype rather than a hard-coded
        # float32 so that non-float32 computation does not hit a dtype mismatch.
        dk = tf.cast(tf.shape(key_layer)[-1], attention_scores.dtype)
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFBertMainLayer call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)

        # Merge the heads back: (batch, num_heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs


class TFBertSelfOutput(tf.keras.layers.Layer):
    """Post-attention block: Dense projection, dropout, then residual add + LayerNorm."""
    def __init__(self, config, **kwargs):
        super(TFBertSelfOutput, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name='LayerNorm')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """inputs: ``[hidden_states, input_tensor]``, where ``input_tensor`` is the residual branch."""
        hidden_states, residual = inputs
        projected = self.dropout(self.dense(hidden_states), training=training)
        return self.LayerNorm(projected + residual)


class TFBertAttention(tf.keras.layers.Layer):
    """Self-attention followed by its output projection and residual LayerNorm."""
    def __init__(self, config, **kwargs):
        super(TFBertAttention, self).__init__(**kwargs)
        self.self_attention = TFBertSelfAttention(config, name='self')
        self.dense_output = TFBertSelfOutput(config, name='output')

    def prune_heads(self, heads):
        """Head pruning is not implemented in this TF 2.0 port."""
        raise NotImplementedError

    def call(self, inputs, training=False):
        """inputs: ``[input_tensor, attention_mask, head_mask]``.

        Returns ``(attention_output,)``, followed by the attention probabilities
        when the self-attention layer outputs them.
        """
        input_tensor, attention_mask, head_mask = inputs
        self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
        context, extras = self_outputs[0], self_outputs[1:]
        attention_output = self.dense_output([context, input_tensor], training=training)
        return (attention_output,) + extras


class TFBertIntermediate(tf.keras.layers.Layer):
    """Feed-forward expansion: Dense(intermediate_size) followed by the configured activation."""
    def __init__(self, config, **kwargs):
        super(TFBertIntermediate, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.intermediate_size, name='dense')
        act = config.hidden_act
        # A string (or, on Python 2, a ``unicode``) is looked up in ACT2FN;
        # anything else is used as the activation callable itself.
        act_is_name = isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode))  # noqa: F821
        self.intermediate_act_fn = ACT2FN[act] if act_is_name else act

    def call(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))


class TFBertOutput(tf.keras.layers.Layer):
    """Feed-forward output block: Dense projection, dropout, then residual add + LayerNorm."""
    def __init__(self, config, **kwargs):
        super(TFBertOutput, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name='LayerNorm')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        """inputs: ``[hidden_states, input_tensor]``, where ``input_tensor`` is the residual branch."""
        hidden_states, residual = inputs
        projected = self.dropout(self.dense(hidden_states), training=training)
        return self.LayerNorm(projected + residual)


class TFBertLayer(tf.keras.layers.Layer):
    """One transformer block: attention sub-layer, then the feed-forward sub-layer."""
    def __init__(self, config, **kwargs):
        super(TFBertLayer, self).__init__(**kwargs)
        self.attention = TFBertAttention(config, name='attention')
        self.intermediate = TFBertIntermediate(config, name='intermediate')
        self.bert_output = TFBertOutput(config, name='output')

    def call(self, inputs, training=False):
        """inputs: ``[hidden_states, attention_mask, head_mask]``.

        Returns ``(layer_output,)`` plus the attention probabilities if they are output.
        """
        hidden_states, attention_mask, head_mask = inputs
        attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attended = attention_outputs[0]
        expanded = self.intermediate(attended)
        layer_output = self.bert_output([expanded, attended], training=training)
        return (layer_output,) + attention_outputs[1:]


class TFBertEncoder(tf.keras.layers.Layer):
    """Stack of ``config.num_hidden_layers`` TFBertLayer blocks applied in sequence."""
    def __init__(self, config, **kwargs):
        super(TFBertEncoder, self).__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [TFBertLayer(config, name='layer_._{}'.format(i)) for i in range(config.num_hidden_layers)]

    def call(self, inputs, training=False):
        """Run every layer.

        Args:
            inputs: ``[hidden_states, attention_mask, head_mask]``, where ``head_mask`` is
                indexable per layer (one entry per layer).
            training: whether dropout is active.
        Returns:
            ``(last_hidden_states, [all_hidden_states], [all_attentions])``. The optional
            entries depend on the config. ``all_hidden_states`` holds each layer's input
            plus the final output.
        """
        hidden_states, attention_mask, head_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        for layer_idx, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer_module([hidden_states, attention_mask, head_mask[layer_idx]], training=training)
            hidden_states = layer_outputs[0]
            if self.output_attentions:
                all_attentions += (layer_outputs[1],)

        outputs = [hidden_states]
        if self.output_hidden_states:
            # Also record the output of the last layer.
            outputs.append(all_hidden_states + (hidden_states,))
        if self.output_attentions:
            outputs.append(all_attentions)
        return tuple(outputs)  # outputs, (hidden states), (attentions)


class TFBertPooler(tf.keras.layers.Layer):
    """Pools a sequence by passing its first token's hidden state through Dense + tanh."""
    def __init__(self, config, **kwargs):
        super(TFBertPooler, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')

    def call(self, hidden_states):
        # "Pool" the model by taking the hidden state corresponding to the first token.
        return self.dense(hidden_states[:, 0])


class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    """Dense -> activation -> LayerNorm, applied before the LM output projection."""
    def __init__(self, config, **kwargs):
        super(TFBertPredictionHeadTransform, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
        act = config.hidden_act
        # A string (or, on Python 2, a ``unicode``) is looked up in ACT2FN;
        # anything else is used as the activation callable itself.
        act_is_name = isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode))  # noqa: F821
        self.transform_act_fn = ACT2FN[act] if act_is_name else act
        self.LayerNorm = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name='LayerNorm')

    def call(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)


class TFBertLMPredictionHead(tf.keras.layers.Layer):
    """Masked-LM output layer.

    Transforms the hidden states, projects them onto the vocabulary with the
    (shared) input embedding layer in ``mode="linear"``, then adds a per-token bias.
    """
    def __init__(self, config, input_embeddings, **kwargs):
        super(TFBertLMPredictionHead, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name='transform')
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        """Create the output-only bias of shape ``(vocab_size,)``."""
        self.bias = self.add_weight(name='bias',
                                    shape=(self.vocab_size,),
                                    initializer='zeros',
                                    trainable=True)
        super(TFBertLMPredictionHead, self).build(input_shape)

    def call(self, hidden_states):
        transformed = self.transform(hidden_states)
        logits = self.input_embeddings(transformed, mode="linear")
        return logits + self.bias


class TFBertMLMHead(tf.keras.layers.Layer):
    """Masked-LM head: holds the LM prediction head as its ``predictions`` sub-layer."""
    def __init__(self, config, input_embeddings, **kwargs):
        super(TFBertMLMHead, self).__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name='predictions')

    def call(self, sequence_output):
        return self.predictions(sequence_output)


class TFBertNSPHead(tf.keras.layers.Layer):
    """Next-sentence-prediction head: a 2-unit Dense layer applied to the pooled output."""
    def __init__(self, config, **kwargs):
        super(TFBertNSPHead, self).__init__(**kwargs)
        self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')

    def call(self, pooled_output):
        return self.seq_relationship(pooled_output)


class TFBertMainLayer(tf.keras.layers.Layer):
    """The bare BERT stack: embeddings -> transformer encoder -> pooler."""
    def __init__(self, config, **kwargs):
        super(TFBertMainLayer, self).__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers

        self.embeddings = TFBertEmbeddings(config, name='embeddings')
        self.encoder = TFBertEncoder(config, name='encoder')
        self.pooler = TFBertPooler(config, name='pooler')

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
        """Run the model.

        Args:
            inputs: one of the following.
                - the ``input_ids`` tensor alone;
                - a list/tuple ``[input_ids, attention_mask, token_type_ids, position_ids,
                  head_mask]`` (any prefix of it);
                - a dict keyed by those names.
                Values found in ``inputs`` take precedence over the matching keyword
                arguments.
            attention_mask: 1 for positions to attend to, 0 for masked ones; all ones if None.
            token_type_ids: segment ids; all zeros if None.
            position_ids: forwarded to the embeddings layer, which uses ``0..length-1`` when None.
            head_mask: not supported yet; must be None.
            training: whether dropout is active.

        Returns:
            ``(sequence_output, pooled_output, [hidden_states], [attentions])``. The last
            two are present depending on the config.

        Raises:
            NotImplementedError: if a ``head_mask`` is given.
        """
        if isinstance(inputs, (tuple, list)):
            assert len(inputs) <= 5, "Too many inputs."
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
        elif isinstance(inputs, dict):
            assert len(inputs) <= 5, "Too many inputs."
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
        else:
            input_ids = inputs

        if attention_mask is None:
            attention_mask = tf.fill(tf.shape(input_ids), 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(tf.shape(input_ids), 0)

        # We create a 4D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Head masking is not implemented yet: every encoder layer receives None.
        if head_mask is not None:
            raise NotImplementedError("head_mask is not supported yet in the TF 2.0 BERT model")
        head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

class TFBertPreTrainedModel(TFPreTrainedModel):
    """Abstract base class for the TF 2.0 BERT models.

    Handles weights initialization and provides a simple interface for downloading
    and loading pretrained models. It exposes ``load_bert_pt_weights_in_tf2`` as
    ``load_pt_weights`` for PyTorch checkpoints.
    """
    # Configuration class for BERT models (presumably consumed by TFPreTrainedModel — confirm).
    config_class = BertConfig
    # Shortcut name -> URL of the pretrained TF 2.0 weights (.h5 files).
    pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    # Loader used for PyTorch checkpoints.
    load_pt_weights = load_bert_pt_weights_in_tf2
    # NOTE(review): presumably the attribute name of the main layer in subclasses — confirm.
    base_model_prefix = "bert"


BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument:

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
            `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
            (Not yet implemented in the TF 2.0 model: passing a head mask raises ``NotImplementedError``.)
"""

@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertModel(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``tf.Tensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertModel

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertModel.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertModel, self).__init__(config, *inputs, **kwargs)
        # NOTE(review): the layer name 'bert' presumably has to match the weight
        # names in pretrained checkpoints -- keep it stable.
        self.bert = TFBertMainLayer(config, name='bert')

    def call(self, inputs, **kwargs):
        """Run the BERT main layer and return its output tuple unchanged.

        ``inputs`` and ``kwargs`` (e.g. ``training``) are forwarded verbatim to
        ``TFBertMainLayer``.
        """
        outputs = self.bert(inputs, **kwargs)
        return outputs


@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForPreTraining(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **seq_relationship_scores**: ``tf.Tensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForPreTraining

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForPreTraining, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
        # NOTE(review): the '___cls' suffix in these layer names presumably exists
        # so that the weight names line up with the pretrained checkpoint's 'cls'
        # scope -- do not rename without checking weight loading.
        self.nsp = TFBertNSPHead(config, name='nsp___cls')
        # The MLM head receives the embeddings layer; presumably the output
        # projection is tied to the input word embeddings -- TODO confirm in TFBertMLMHead.
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

    def call(self, inputs, **kwargs):
        """Compute MLM scores from the sequence output and NSP scores from the pooled output.

        Returns ``(prediction_scores, seq_relationship_score)`` followed by any
        hidden states / attentions that the main layer returned.
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
        # NSP is applied to the pooled (batch_size, hidden_size) output, so its
        # scores carry no sequence dimension.
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForMaskedLM

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
        # NOTE(review): '___cls' suffix presumably matches the checkpoint's 'cls'
        # weight scope; the embeddings layer is handed to the head (likely weight tying).
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

    def call(self, inputs, **kwargs):
        """Compute masked-LM scores from the last-layer sequence output.

        Returns ``(prediction_scores,)`` followed by any hidden states / attentions
        that the main layer returned. The pooled output is discarded.
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))

        # outputs[1] (pooled output) is intentionally dropped.
        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        return outputs  # prediction_scores, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **seq_relationship_scores**: ``tf.Tensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForNextSentencePrediction

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForNextSentencePrediction, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
        # NOTE(review): '___cls' suffix presumably matches the checkpoint's 'cls' weight scope.
        self.nsp = TFBertNSPHead(config, name='nsp___cls')

    def call(self, inputs, **kwargs):
        """Compute next-sentence-prediction scores from the pooled output.

        Returns ``(seq_relationship_score,)`` followed by any hidden states /
        attentions that the main layer returned.
        """
        outputs = self.bert(inputs, **kwargs)

        # The NSP head only sees the pooled (batch_size, hidden_size) output.
        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForSequenceClassification(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForSequenceClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

    def call(self, inputs, **kwargs):
        """Classify the whole sequence from the (dropout-regularized) pooled output.

        Returns ``(logits,)`` followed by any hidden states / attentions that the
        main layer returned. Dropout is active only when ``training=True`` is passed.
        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # logits, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForMultipleChoice(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **classification_scores**: ``tf.Tensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size * num_choices, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size * num_choices, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForMultipleChoice

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :]  # Batch size 1, 2 choices
        outputs = model(input_ids)
        classification_scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForMultipleChoice, self).__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        # One score per (example, choice) pair; scores are regrouped per example in call().
        self.classifier = tf.keras.layers.Dense(1, name='classifier')

    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
        """Score every choice of every example.

        Args:
            inputs: either ``input_ids`` of shape ``(batch_size, num_choices, sequence_length)``,
                a list/tuple ``[input_ids, attention_mask, token_type_ids, position_ids, head_mask]``
                (trailing entries optional), or a dict keyed by those names.
            attention_mask, token_type_ids, position_ids: optional tensors shaped like
                ``input_ids``; values found inside ``inputs`` take precedence.
            head_mask: optional head mask, forwarded to the main layer unchanged.
            training: enables dropout when True.

        Returns:
            ``(reshaped_logits,)`` with logits of shape ``(batch_size, num_choices)``, followed
            by any hidden states / attentions returned by the main layer (these keep the
            flattened ``batch_size * num_choices`` leading dimension).
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            assert len(inputs) <= 5, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 5, "Too many inputs."
        else:
            input_ids = inputs

        # input_ids is rank 3: (batch_size, num_choices, seq_length).
        num_choices = tf.shape(input_ids)[1]
        seq_length = tf.shape(input_ids)[2]

        # Fold the choices into the batch dimension so the plain BERT encoder can
        # process each (example, choice) pair as an independent sequence.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]

        outputs = self.bert(flat_inputs, training=training)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        # Unfold back: one row per example, one column per choice.
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForTokenClassification(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForTokenClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForTokenClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')

    def call(self, inputs, **kwargs):
        """Classify every token from the (dropout-regularized) last-layer sequence output.

        Returns ``(logits,)`` followed by any hidden states / attentions that the
        main layer returned. Dropout is active only when ``training=True`` is passed.
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # scores, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **start_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length)``
            Span-start scores (before SoftMax).
        **end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import BertTokenizer, TFBertForQuestionAnswering

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForQuestionAnswering.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        start_scores, end_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name='bert')
        # NOTE: call() splits this layer's output into exactly two halves and squeezes
        # each to rank 2, which only works when config.num_labels == 2.
        self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')

    def call(self, inputs, **kwargs):
        """Compute per-token span-start and span-end logits.

        Returns ``(start_logits, end_logits)``, each of shape ``(batch_size, sequence_length)``,
        followed by any hidden states / attentions that the main layer returned.
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Last axis has num_labels (== 2) entries: [start, end].
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        return outputs  # start_logits, end_logits, (hidden_states), (attentions)