modeling_tf_distilbert.py 34.2 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 DistilBERT model
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import copy
import sys
from io import open

import itertools

import numpy as np
import tensorflow as tf

from .configuration_distilbert import DistilBertConfig
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)


TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-tf_model.h5",
    'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-tf_model.h5"
}


### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
def gelu(x):
    """Exact Gaussian Error Linear Unit, ``x * Phi(x)``.

    ``Phi`` is the standard normal CDF, evaluated here through ``tf.math.erf``.
    This is the formulation used by the original Google BERT implementation;
    the tanh-based approximation used by OpenAI GPT (see ``gelu_new``) gives
    slightly different values.
    Reference: https://arxiv.org/abs/1606.08415
    """
    standard_normal_cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * standard_normal_cdf

def gelu_new(x):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    a smoother variant of ReLU.
    Reference: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation on.
    Returns:
        ``x`` with the approximate GELU applied.
    """
    scale = np.sqrt(2 / np.pi)
    inner = scale * (x + 0.044715 * tf.pow(x, 3))
    cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * cdf

def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Load a PyTorch DistilBERT checkpoint into a TF 2.0 DistilBERT model.

    The model is first called once on a small dummy batch so that its
    variables get built, then the PyTorch weights are loaded into it.

    Args:
        tf_model: the TF 2.0 DistilBERT model to fill.
        pytorch_checkpoint_path: path to the PyTorch checkpoint.
    Returns:
        The result of ``load_pytorch_checkpoint_in_tf2_model``.
    """
    # Dummy (input_ids, attention_mask) batch: 3 sequences of length 5.
    inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    tf_inputs = [inputs_list, attns_list]
    # Build the network: the forward pass is run only for this side effect,
    # so its output is deliberately discarded.
    tf_model(tf_inputs, training=False)
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

class TFEmbeddings(tf.keras.layers.Layer):
    """DistilBERT embeddings: word + learned position embeddings, then LayerNorm and dropout.

    The word-embedding matrix is also used as the output projection:
    ``call(..., mode="linear")`` multiplies hidden states by its transpose to
    produce vocabulary logits.
    """
    def __init__(self, config, **kwargs):
        super(TFEmbeddings, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        # NOTE(review): this sub-layer is overwritten in `build` by a raw weight stored
        # under the same attribute name. Left as-is because removing it may change Keras
        # layer tracking / weight naming — TODO confirm before cleaning up.
        self.word_embeddings = TFSharedEmbeddings(config.vocab_size, config.dim, name='word_embeddings')  # padding_idx=0)
        self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.dim, name='position_embeddings')
        if config.sinusoidal_pos_embds:
            # Fixed sinusoidal position embeddings are not supported by this TF port.
            raise NotImplementedError

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.dim],
                initializer=tf.random_normal_initializer(
                    mean=0., stddev=self.dim**-0.5))
        super(TFEmbeddings, self).build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Get token embeddings of inputs, or vocabulary logits of hidden states.

        Args:
            inputs: for mode == "embedding", either an int tensor of input_ids with
                shape [batch_size, length], or a list/tuple (input_ids, position_ids)
                where position_ids may be None (then 0..length-1 is used).
                DistilBERT has no token_type_ids.
                For mode == "linear", a float32 tensor with shape [batch_size, length, dim].
            mode: string, a valid value is one of "embedding" and "linear".
            training: bool, enables dropout (only used in "embedding" mode).
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """
        Parameters
        ----------
        inputs: tf.Tensor(bs, max_seq_length) of input_ids, or a list/tuple
            (input_ids, position_ids) with position_ids possibly None.
            The token ids to embed.

        Outputs
        -------
        embeddings: tf.Tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        if not isinstance(inputs, (tuple, list)):
            input_ids = inputs
            position_ids = None
        else:
            input_ids, position_ids = inputs

        seq_length = tf.shape(input_ids)[1]
        if position_ids is None:
            # Shape (1, seq_length): broadcast over the batch when added below.
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]

        word_embeddings = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings            # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)                       # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings, training=training)      # (bs, max_seq_length, dim)
        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
            Returns:
                float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = tf.shape(inputs)[0]
        length = tf.shape(inputs)[1]

        x = tf.reshape(inputs, [-1, self.dim])
        # Project onto the (tied) word-embedding matrix: (N, dim) x (dim, vocab_size).
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention with q/k/v/out linear projections."""
    def __init__(self, config, **kwargs):
        super(TFMultiHeadSelfAttention, self).__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0

        self.q_lin = tf.keras.layers.Dense(config.dim, name="q_lin")
        self.k_lin = tf.keras.layers.Dense(config.dim, name="k_lin")
        self.v_lin = tf.keras.layers.Dense(config.dim, name="v_lin")
        self.out_lin = tf.keras.layers.Dense(config.dim, name="out_lin")

        self.pruned_heads = set()

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, inputs, training=False):
        """
        Parameters
        ----------
        inputs: list/tuple ``(query, key, value, mask, head_mask)``:
            query: tf.Tensor(bs, q_length, dim)
            key: tf.Tensor(bs, k_length, dim)
            value: tf.Tensor(bs, k_length, dim)
            mask: tf.Tensor(bs, k_length)
                1.0 for positions to attend to, 0.0 for masked positions.
            head_mask: None, or a tensor multiplied into the attention weights
                (presumably broadcastable to (bs, n_heads, q_length, k_length) — TODO confirm).
        training: bool, enables dropout on the attention weights.

        Outputs
        -------
        ``(context,)``, or ``(context, weights)`` when ``output_attentions=True``:
        context: tf.Tensor(bs, q_length, dim)
            Contextualized layer.
        weights: tf.Tensor(bs, n_heads, q_length, k_length)
            Attention weights. Optional: only if `output_attentions=True`
        """
        query, key, value, mask, head_mask = inputs
        bs, q_length, dim = shape_list(query)
        k_length = shape_list(key)[1]

        dim_per_head = self.dim // self.n_heads

        # NOTE: only a 2-D (bs, k_length) mask fits the reshape below; a 3-D mask
        # passes this check but would generally fail the reshape.
        assert 2 <= len(tf.shape(mask)) <= 3
        mask_reshape = [bs, 1, 1, k_length]

        def shape(x):
            """ separate heads """
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

        def unshape(x):
            """ group heads """
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(query))           # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))             # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))           # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)                     # (bs, n_heads, q_length, dim_per_head)
        scores = tf.matmul(q, k, transpose_b=True)          # (bs, n_heads, q_length, k_length)
        mask = tf.reshape(mask, mask_reshape)               # (bs, 1, 1, k_length), broadcast over heads/queries
        # Masked positions (mask == 0) get a huge negative score so softmax gives them ~0 weight.
        scores = scores - 1e30 * (1.0 - mask)

        weights = tf.nn.softmax(scores, axis=-1)                              # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)                    # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = tf.matmul(weights, v)        # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)             # (bs, q_length, dim)
        context = self.out_lin(context)        # (bs, q_length, dim)

        if self.output_attentions:
            return (context, weights)
        else:
            return (context,)

class TFFFN(tf.keras.layers.Layer):
    """Position-wise feed-forward sub-block.

    Applies Dense(hidden_dim) -> activation (gelu or relu) -> Dense(dim) -> dropout.
    """
    def __init__(self, config, **kwargs):
        super(TFFFN, self).__init__(**kwargs)
        # Sub-layer creation order is kept stable (it influences Keras auto-naming).
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.lin1 = tf.keras.layers.Dense(config.hidden_dim, name="lin1")
        self.lin2 = tf.keras.layers.Dense(config.dim, name="lin2")
        assert config.activation in ['relu', 'gelu'], "activation ({}) must be in ['relu', 'gelu']".format(config.activation)
        if config.activation == 'gelu':
            self.activation = tf.keras.layers.Activation(gelu)
        else:
            self.activation = tf.keras.activations.relu

    def call(self, input, training=False):
        # `input` (shadowing the builtin) is kept as the parameter name for interface stability.
        expanded = self.activation(self.lin1(input))
        projected = self.lin2(expanded)
        return self.dropout(projected, training=training)


class TFTransformerBlock(tf.keras.layers.Layer):
    """One DistilBERT layer: self-attention and feed-forward, each with a residual + LayerNorm."""
    def __init__(self, config, **kwargs):
        super(TFTransformerBlock, self).__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.activation = config.activation
        self.output_attentions = config.output_attentions

        assert config.dim % config.n_heads == 0

        self.attention = TFMultiHeadSelfAttention(config, name="attention")
        self.sa_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm")

        self.ffn = TFFFN(config, name="ffn")
        self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")

    def call(self, inputs, training=False):  # removed: src_enc=None, src_len=None
        """
        Parameters
        ----------
        inputs: list/tuple ``(x, attn_mask, head_mask)``:
            x: tf.Tensor(bs, seq_length, dim)
            attn_mask: tf.Tensor(bs, seq_length)
            head_mask: None, or a mask forwarded to the attention layer.

        Outputs
        -------
        ``(ffn_output,)``, or ``(sa_weights, ffn_output)`` when ``output_attentions=True``:
        sa_weights: tf.Tensor(bs, n_heads, seq_length, seq_length)
            The attention weights
        ffn_output: tf.Tensor(bs, seq_length, dim)
            The output of the transformer block contextualization.
        """
        x, attn_mask, head_mask = inputs

        # Self-attention sub-block; the attention layer always returns a tuple.
        attention_outputs = self.attention([x, x, x, attn_mask, head_mask], training=training)
        if self.output_attentions:
            attended, sa_weights = attention_outputs   # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:
            attended = attention_outputs[0]
        normed_attention = self.sa_layer_norm(attended + x)       # (bs, seq_length, dim)

        # Feed-forward sub-block with residual connection.
        ffn_result = self.ffn(normed_attention, training=training)
        block_output = self.output_layer_norm(ffn_result + normed_attention)  # (bs, seq_length, dim)

        if self.output_attentions:
            return (sa_weights, block_output)
        return (block_output,)


class TFTransformer(tf.keras.layers.Layer):
    """Stack of ``config.n_layers`` TFTransformerBlock layers."""
    def __init__(self, config, **kwargs):
        super(TFTransformer, self).__init__(**kwargs)
        self.n_layers = config.n_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.layer = [TFTransformerBlock(config, name='layer_._{}'.format(i))
                      for i in range(config.n_layers)]

    def call(self, inputs, training=False):
        """
        Parameters
        ----------
        inputs: list/tuple ``(x, attn_mask, head_mask)``:
            x: tf.Tensor(bs, seq_length, dim)
                Input sequence embedded.
            attn_mask: tf.Tensor(bs, seq_length)
                Attention mask on the sequence.
            head_mask: sequence with one entry (possibly None) per layer, or None
                to disable head masking in every layer.

        Outputs
        -------
        hidden_state: tf.Tensor(bs, seq_length, dim)
            Sequence of hiddens states in the last (top) layer
        all_hidden_states: Tuple[tf.Tensor(bs, seq_length, dim)]
            Tuple of length n_layers + 1: the input embeddings followed by the
            hidden states output by each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        x, attn_mask, head_mask = inputs

        all_hidden_states = ()
        all_attentions = ()

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            # head_mask=None means "no head masking in any layer".
            layer_head_mask = None if head_mask is None else head_mask[i]
            layer_outputs = layer_module([hidden_state, attn_mask, layer_head_mask], training=training)
            # The block's hidden state is always its last output.
            hidden_state = layer_outputs[-1]

            if self.output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class TFDistilBertMainLayer(tf.keras.layers.Layer):
    """Bare DistilBERT encoder: embeddings followed by the transformer stack."""
    def __init__(self, config, **kwargs):
        super(TFDistilBertMainLayer, self).__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers

        self.embeddings = TFEmbeddings(config, name="embeddings")   # Embeddings
        self.transformer = TFTransformer(config, name="transformer") # Encoder

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    def call(self, inputs, attention_mask=None, head_mask=None, training=False):
        """Embed ``input_ids`` and run them through the transformer encoder.

        Args:
            inputs: one of
                - a tensor of input_ids,
                - a list/tuple ``[input_ids, attention_mask, head_mask]`` (trailing items optional),
                - a dict with keys ``input_ids``, ``attention_mask``, ``head_mask`` (at most 3 entries).
                Values found in ``inputs`` override the keyword arguments.
            attention_mask: (optional) (bs, seq_length) mask, 1 for tokens to attend to
                and 0 for masked ones. Defaults to all ones.
            head_mask: (optional) not implemented yet; must be None.
            training: bool, enables dropout in the embeddings and the transformer.

        Returns:
            Tuple: last-layer hidden state, (all hidden states), (all attentions).

        Raises:
            NotImplementedError: if a head_mask is given.
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            head_mask = inputs[2] if len(inputs) > 2 else head_mask
            assert len(inputs) <= 3, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 3, "Too many inputs."
        else:
            input_ids = inputs

        if attention_mask is None:
            attention_mask = tf.ones(shape_list(input_ids)) # (bs, seq_length)
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers

        # Forward `training` so embedding dropout is active during training, exactly like
        # the transformer call below (TFEmbeddings.call defaults to training=False).
        embedding_output = self.embeddings(input_ids, training=training)   # (bs, seq_length, dim)
        tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)

        return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
thomwolf's avatar
thomwolf committed
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458


### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
    """ Abstract base class for the TF 2.0 DistilBERT models.

        Binds the DistilBERT-specific pieces (configuration class, pretrained
        archive map, PyTorch-to-TF weight loader, base model attribute name)
        that the generic ``TFPreTrainedModel`` machinery presumably uses for
        weight initialization and for downloading/loading pretrained models.
    """
    base_model_prefix = "distilbert"
    config_class = DistilBertConfig
    pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_pt_weights = load_distilbert_pt_weights_in_tf2


DISTILBERT_START_DOCSTRING = r"""
    DistilBERT is a small, fast, cheap and light Transformer model
    trained by distilling Bert base. It has 40% fewer parameters than
    `bert-base-uncased`, runs 60% faster while preserving over 95% of
    Bert's performance as measured on the GLUE language understanding benchmark.

    Here are the differences between the interface of Bert and DistilBert:

    - DistilBert doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
    - DistilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let us know if you need this option.

    For more information on DistilBERT, please refer to our
    `detailed blog post`_

    This model is a `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. _`detailed blog post`:
        https://medium.com/huggingface/distilbert-8cf3380435b5

    .. _`tf.keras.Model`:
        https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using `tf.keras.Model.fit()` method which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument :

        - a single Tensor with input_ids only and nothing else: `model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([input_ids])` or `model([input_ids, attention_mask])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
            `model({'input_ids': input_ids, 'attention_mask': attention_mask})`

    Parameters:
        config (:class:`~pytorch_transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

DISTILBERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids** ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The input sequences should start with `[CLS]` and end with `[SEP]` tokens.

            For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DistilBERT.
        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
            Not implemented yet in the TF 2.0 DistilBERT models: passing a value other than ``None`` raises ``NotImplementedError``.
"""

@add_start_docstrings("The bare DistilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertModel(TFDistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``tf.Tensor`` (one for the output of the embeddings + one for the output of each layer)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import DistilBertTokenizer, TFDistilBertModel

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFDistilBertModel, self).__init__(config, *inputs, **kwargs)
        # Full encoder (embeddings + transformer stack).
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")

    def call(self, inputs, **kwargs):
        # Thin wrapper: all input parsing happens in TFDistilBertMainLayer.call.
        outputs = self.distilbert(inputs, **kwargs)
        return outputs


thomwolf's avatar
thomwolf committed
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
class TFDistilBertLMHead(tf.keras.layers.Layer):
    """Masked-LM output head tied to the input word embeddings.

    Scores come from the embedding layer's ``mode="linear"`` call (for
    TFEmbeddings, a matmul with the transposed word-embedding matrix), plus a
    learned output-only bias with one entry per vocabulary token.
    """
    def __init__(self, config, input_embeddings, **kwargs):
        super(TFDistilBertLMHead, self).__init__(**kwargs)
        self.vocab_size = config.vocab_size

        # Weight tying: the output projection reuses the input embedding layer;
        # only the bias created in `build` belongs to this head.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(name='bias',
                                    shape=(self.vocab_size,),
                                    initializer='zeros',
                                    trainable=True)
        super(TFDistilBertLMHead, self).build(input_shape)

    def call(self, hidden_states):
        vocab_scores = self.input_embeddings(hidden_states, mode="linear")
        return vocab_scores + self.bias


thomwolf's avatar
thomwolf committed
565
566
567
568
569
@add_start_docstrings("""DistilBert Model with a `masked language modeling` head on top. """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    This head only returns scores; it does not compute a loss.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import DistilBertTokenizer, TFDistilBertForMaskedLM

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = TFDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFDistilBertForMaskedLM, self).__init__(config, *inputs, **kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.vocab_size = config.vocab_size

        # Sub-layers are created in a fixed order; keep it stable so weight
        # tracking/loading is unaffected.
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        # Prediction head: dense -> gelu -> layer norm -> tied vocabulary projection.
        self.vocab_transform = tf.keras.layers.Dense(config.dim, name="vocab_transform")
        self.act = tf.keras.layers.Activation(gelu)
        self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
        # Output projection reuses the main layer's input embeddings as its weights.
        self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

    def call(self, inputs, **kwargs):
        """Compute masked-LM scores for every position.

        All keyword arguments (e.g. ``training``) are forwarded to the main layer.

        Returns:
            ``(prediction_scores,)`` followed by the main layer's optional
            hidden states / attentions.
        """
        distilbert_output = self.distilbert(inputs, **kwargs)

        hidden_states = distilbert_output[0]                          # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)       # (bs, seq_length, dim)
        prediction_logits = self.act(prediction_logits)               # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)   # (bs, seq_length, vocab_size)

        outputs = (prediction_logits,) + distilbert_output[1:]
        return outputs  # logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
                         the pooled output) e.g. for GLUE tasks. """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFDistilBertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        # Sub-layer creation order is kept stable so weight tracking/loading is unaffected.
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.pre_classifier = tf.keras.layers.Dense(config.dim, activation='relu', name="pre_classifier")
        self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
        self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)

    def call(self, inputs, **kwargs):
        """Classify each sequence from the hidden state at its first position.

        All keyword arguments are forwarded to the main layer; ``training``
        (default ``False``) additionally toggles the classifier dropout.

        Returns:
            ``(logits,)`` followed by the main layer's optional hidden states /
            attentions.
        """
        distilbert_output = self.distilbert(inputs, **kwargs)

        hidden_state = distilbert_output[0]                  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]                   # (bs, dim): first position of each sequence
        pooled_output = self.pre_classifier(pooled_output)   # (bs, dim)
        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))  # (bs, dim)
        logits = self.classifier(pooled_output)              # (bs, num_labels)

        outputs = (logits,) + distilbert_output[1:]
        return outputs  # logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
                         the hidden-states output to compute `span start logits` and `span end logits`). """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **start_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    This head only returns span scores; it does not compute a loss.

    Examples::

        import tensorflow as tf
        from pytorch_transformers import DistilBertTokenizer, TFDistilBertForQuestionAnswering

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        start_scores, end_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFDistilBertForQuestionAnswering, self).__init__(config, *inputs, **kwargs)

        # Sub-layer creation order is kept stable so weight tracking/loading is unaffected.
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
        # call() splits the head output into exactly two channels (start, end).
        assert config.num_labels == 2
        self.dropout = tf.keras.layers.Dropout(config.qa_dropout)

    def call(self, inputs, **kwargs):
        """Compute span-start and span-end scores for every position.

        All keyword arguments are forwarded to the main layer; ``training``
        (default ``False``) additionally toggles the dropout before the head.

        Returns:
            ``(start_logits, end_logits)`` followed by the main layer's optional
            hidden states / attentions.
        """
        distilbert_output = self.distilbert(inputs, **kwargs)

        hidden_states = distilbert_output[0]                                                 # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states, training=kwargs.get('training', False))  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)                  # (bs, max_query_len, 2)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)  # each (bs, max_query_len, 1)
        start_logits = tf.squeeze(start_logits, axis=-1)         # (bs, max_query_len)
        end_logits = tf.squeeze(end_logits, axis=-1)             # (bs, max_query_len)

        outputs = (start_logits, end_logits,) + distilbert_output[1:]
        return outputs  # start_logits, end_logits, (hidden_states), (attentions)