"tests/tokenization_t5_test.py" did not exist on "73f2c342f53f2ff02124da23ba029d80c386e7ce"
modeling_distilbert.py 37.7 KB
Newer Older
VictorSanh's avatar
wip  
VictorSanh committed
1
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DistilBERT model
    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
    and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""
Aymeric Augustin's avatar
Aymeric Augustin committed
19

VictorSanh's avatar
wip  
VictorSanh committed
20

Aymeric Augustin's avatar
Aymeric Augustin committed
21
import copy
VictorSanh's avatar
wip  
VictorSanh committed
22
23
24
25
26
27
import logging
import math

import numpy as np
import torch
import torch.nn as nn
28
from torch.nn import CrossEntropyLoss
VictorSanh's avatar
wip  
VictorSanh committed
29

30
31
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings
Aymeric Augustin's avatar
Aymeric Augustin committed
32
from .modeling_utils import PreTrainedModel, prune_linear_layer
VictorSanh's avatar
wip  
VictorSanh committed
33

34

VictorSanh's avatar
wip  
VictorSanh committed
35
36
37
logger = logging.getLogger(__name__)


# Shortcut model name -> URL of the matching pretrained PyTorch weights file.
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    [
        (
            "distilbert-base-uncased",
            "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-pytorch_model.bin",
        ),
        (
            "distilbert-base-uncased-distilled-squad",
            "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin",
        ),
        (
            "distilbert-base-german-cased",
            "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-pytorch_model.bin",
        ),
        (
            "distilbert-base-multilingual-cased",
            "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-multilingual-cased-pytorch_model.bin",
        ),
    ]
)


46
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
VictorSanh's avatar
wip  
VictorSanh committed
47
48
49
def gelu(x):
    """Gaussian Error Linear Unit (exact, erf-based form).

    Returns ``x * Phi(x)`` where ``Phi`` is the standard normal CDF,
    written as ``0.5 * (1 + erf(x / sqrt(2)))``.
    """
    standard_normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * standard_normal_cdf

50

VictorSanh's avatar
wip  
VictorSanh committed
51
def create_sinusoidal_embeddings(n_pos, dim, out):
    """Overwrite ``out`` in place with fixed sinusoidal position encodings and freeze it.

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (2 * (j // 2) / dim))`` for even ``j``
    and the ``cos`` of the same angle for odd ``j``.

    Parameters
    ----------
    n_pos: int
        Number of positions (rows of ``out``).
    dim: int
        Encoding size (columns of ``out``).
    out: torch.Tensor of shape (n_pos, dim)
        Leaf tensor to fill, e.g. an ``nn.Embedding`` weight. On return it is detached
        and has ``requires_grad == False``.
    """
    positions = np.arange(n_pos)[:, np.newaxis]  # (n_pos, 1)
    columns = np.arange(dim)[np.newaxis, :]  # (1, dim)
    position_enc = positions / np.power(10000, 2 * (columns // 2) / dim)  # (n_pos, dim), float64

    # Freeze BEFORE writing: an in-place write into a leaf tensor that still requires
    # grad (such as an nn.Parameter) is rejected by autograd with a RuntimeError.
    out.requires_grad = False
    with torch.no_grad():
        out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()

58

VictorSanh's avatar
wip  
VictorSanh committed
59
class Embeddings(nn.Module):
    """Token embeddings summed with absolute position embeddings, then LayerNorm and dropout.

    There are no token-type (segment) embeddings.
    """

    def __init__(self, config):
        super(Embeddings, self).__init__()
        hidden_size = config.dim
        max_positions = config.max_position_embeddings

        # Row 0 of the vocabulary table is the padding index.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        if config.sinusoidal_pos_embds:
            # Replace the position table with fixed sinusoidal encodings (frozen in place).
            create_sinusoidal_embeddings(n_pos=max_positions, dim=hidden_size, out=self.position_embeddings.weight)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            Token ids to embed.

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            Word embeddings plus position embeddings (positions 0..seq_length-1),
            layer-normalized and passed through dropout.
        """
        n_tokens = input_ids.size(1)
        # Same position ids 0..n_tokens-1 for every sequence in the batch.
        positions = torch.arange(n_tokens, dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)
        normalized = self.LayerNorm(summed)
        return self.dropout(normalized)

96

VictorSanh's avatar
wip  
VictorSanh committed
97
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with key masking, head masking and head pruning."""

    def __init__(self, config):
        super(MultiHeadSelfAttention, self).__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        # Head indices (in the original, unpruned numbering) removed so far.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove attention heads and shrink the projection layers accordingly.

        Parameters
        ----------
        heads: iterable of int
            Heads to remove, in the ORIGINAL (unpruned) numbering. Heads already pruned
            are ignored; if nothing new remains, the layers are left untouched.
        """
        attention_head_size = self.dim // self.n_heads
        heads = set(heads) - self.pruned_heads
        # Return before rebuilding the Linear layers: rebuilding with a full index would
        # still replace the parameter objects (e.g. detaching them from an optimizer).
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        for head in heads:
            # Map to the current numbering: each already-pruned lower head shifts it down by one.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # q/k/v keep the surviving output rows; out_lin keeps the surviving input columns.
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None):
        """
        Parameters
        ----------
        query: torch.tensor(bs, seq_length, dim)
        key: torch.tensor(bs, seq_length, dim)
        value: torch.tensor(bs, seq_length, dim)
        mask: torch.tensor(bs, seq_length)
            Key mask: key positions where ``mask == 0`` get zero attention weight.
        head_mask: tensor broadcastable to (bs, n_heads, seq_length, seq_length), optional
            Multiplied into the attention weights after softmax and dropout.

        Outputs
        -------
        context: torch.tensor(bs, seq_length, dim)
            Contextualized layer. Always returned, as the first tuple element.
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights. Optional: only returned if `output_attentions=True`.
        """
        bs, _, _ = query.size()  # query is expected to be 3-D: (bs, q_length, dim)
        k_length = key.size(1)
        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = scores.softmax(dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if self.output_attentions:
            return (context, weights)
        else:
            return (context,)
VictorSanh's avatar
wip  
VictorSanh committed
193

194

VictorSanh's avatar
wip  
VictorSanh committed
195
class FFN(nn.Module):
    """Position-wise feed-forward sub-layer: Linear(dim->hidden_dim) -> activation -> Linear(hidden_dim->dim) -> dropout."""

    def __init__(self, config):
        super(FFN, self).__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)

        act_name = config.activation
        assert act_name in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(act_name)
        # gelu is a plain function; relu is a (parameter-free) registered submodule.
        if act_name == "gelu":
            self.activation = gelu
        else:
            self.activation = nn.ReLU()

    def forward(self, input):
        """Apply lin1, the activation, lin2 and dropout to ``input`` (last dimension = config.dim)."""
        expanded = self.activation(self.lin1(input))
        projected = self.lin2(expanded)
        return self.dropout(projected)

213

VictorSanh's avatar
wip  
VictorSanh committed
214
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and FFN, each followed by a residual add and LayerNorm."""

    def __init__(self, config):
        super(TransformerBlock, self).__init__()

        # Hyper-parameters stored as attributes. Of these, only output_attentions is read in
        # forward; NOTE(review): self.dropout is registered but never applied in forward.
        self.n_heads, self.dim, self.hidden_dim = config.n_heads, config.dim, config.hidden_dim
        self.dropout = nn.Dropout(p=config.dropout)
        self.activation = config.activation
        self.output_attentions = config.output_attentions

        assert config.dim % config.n_heads == 0

        # Self-attention sub-layer and the LayerNorm applied after its residual add.
        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        # Feed-forward sub-layer and the LayerNorm applied after its residual add.
        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask=None, head_mask=None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
        attn_mask: torch.tensor(bs, seq_length)
        head_mask: optional, forwarded unchanged to the attention sub-layer

        Outputs
        -------
        Tuple ``(sa_weights, ffn_output)`` when ``output_attentions`` is set, else ``(ffn_output,)``:
        sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            The attention weights
        ffn_output: torch.tensor(bs, seq_length, dim)
            The output of the transformer block contextualization.
        """
        attn_result = self.attention(query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask)
        if self.output_attentions:
            attended, sa_weights = attn_result  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:
            # Even without weights the attention sub-layer returns a 1-tuple.
            assert type(attn_result) is tuple
            attended = attn_result[0]

        # Residual around self-attention, then LayerNorm.
        block_hidden = self.sa_layer_norm(attended + x)  # (bs, seq_length, dim)
        # Residual around the feed-forward network, then LayerNorm.
        block_out = self.output_layer_norm(self.ffn(block_hidden) + block_hidden)  # (bs, seq_length, dim)

        if self.output_attentions:
            return (sa_weights, block_out)
        return (block_out,)
VictorSanh's avatar
wip  
VictorSanh committed
264

265

VictorSanh's avatar
wip  
VictorSanh committed
266
class Transformer(nn.Module):
    """Stack of ``config.n_layers`` TransformerBlock layers applied in sequence."""

    def __init__(self, config):
        super(Transformer, self).__init__()
        self.n_layers = config.n_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        layer = TransformerBlock(config)
        # Every layer is a deep copy of one block, so all layers start with identical weights.
        # NOTE(review): distinct initial weights presumably come from a later weight-init pass — confirm.
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self, x, attn_mask=None, head_mask=None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
            Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequence.
        head_mask: indexable of per-layer head masks, or None (default)
            ``head_mask[i]`` is passed to layer ``i``. None disables head masking in every layer.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers + 1: the input to each layer followed by the final output.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        all_hidden_states = ()
        all_attentions = ()

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            # The default head_mask=None must not be indexed.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(x=hidden_state, attn_mask=attn_mask, head_mask=layer_head_mask)
            # Each layer returns (hidden,) or (attentions, hidden): the hidden state is always last.
            hidden_state = layer_outputs[-1]

            if self.output_attentions:
                assert len(layer_outputs) == 2
                all_attentions = all_attentions + (layer_outputs[0],)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        outputs = (hidden_state,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
VictorSanh's avatar
VictorSanh committed
324
325


326
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
thomwolf's avatar
thomwolf committed
327
class DistilBertPreTrainedModel(PreTrainedModel):
    """Abstract base for DistilBERT models: wires up the config class, the pretrained
    weight archive map and per-module weight initialization.
    """

    config_class = DistilBertConfig
    pretrained_model_archive_map = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None  # no TensorFlow checkpoint conversion for DistilBERT
    base_model_prefix = "distilbert"

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        - ``nn.Embedding``: normal(0, initializer_range), skipped when the weight is frozen
          (``requires_grad`` False, as left by ``create_sinusoidal_embeddings``).
        - ``nn.Linear``: normal(0, initializer_range) weight, zero bias when a bias exists.
        - ``nn.LayerNorm``: weight 1, bias 0.
        Any other module is left unchanged.
        """
        if isinstance(module, nn.Embedding):
            if module.weight.requires_grad:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


thomwolf's avatar
thomwolf committed
352
353
DISTILBERT_START_DOCSTRING = r"""
    DistilBERT is a small, fast, cheap and light Transformer model
    trained by distilling BERT base. It has 40% fewer parameters than
    `bert-base-uncased`, runs 60% faster while preserving over 95% of
    BERT's performance as measured on the GLUE language understanding benchmark.

    Here are the differences between the interface of Bert and DistilBert:

    - DistilBert doesn't have `token_type_ids`, you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token `tokenizer.sep_token` (or `[SEP]`)
    - DistilBert doesn't have options to select the input positions (`position_ids` input). This could be added if necessary though, just let us know if you need this option.

    For more information on DistilBERT, please refer to our
    `detailed blog post`_

    .. _`detailed blog post`:
        https://medium.com/huggingface/distilbert-8cf3380435b5

    Parameters:
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

thomwolf's avatar
thomwolf committed
375
DISTILBERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The input sequences should start with `[CLS]` and end with `[SEP]` tokens.

            For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DistilBERT.
        **attention_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""

396
397
398
399
400
401

@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
    DISTILBERT_INPUTS_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """

    def __init__(self, config):
        super(DistilBertModel, self).__init__(config)

        self.embeddings = Embeddings(config)  # word + position embeddings
        self.transformer = Transformer(config)  # stack of encoder layers

        self.init_weights()

    def get_input_embeddings(self):
        """Return the token (word) embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the token (word) embedding module with ``new_embeddings``."""
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer_index, head_list in heads_to_prune.items():
            self.transformer.layer[layer_index].attention.prune_heads(head_list)

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None):
        """Encode a batch given exactly one of ``input_ids`` or ``inputs_embeds``.

        Note: ``inputs_embeds`` is passed straight to the transformer stack. Unlike
        ``input_ids``, it does NOT go through ``self.embeddings``, so no position
        embeddings, LayerNorm or dropout are applied to it.

        Returns the transformer outputs: (last hidden state, [all hidden states], [all attentions]).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids and has_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not has_ids and not has_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if has_ids:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        if attention_mask is None:
            # Attend to every position.
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        if head_mask is None:
            # One "no mask" entry per layer so the encoder can index it per layer.
            head_mask = [None] * self.config.num_hidden_layers
        else:
            # Reshape to (num_layers, 1, num_heads, 1, 1) so that head_mask[i] broadcasts
            # against layer i's attention weights (bs, n_heads, q_length, k_length).
            if head_mask.dim() == 1:
                # Same head mask for every layer.
                head_mask = head_mask[None, None, :, None, None]
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # One head mask per layer.
                head_mask = head_mask[:, None, :, None, None]
            # Match the model's parameter dtype (e.g. float16).
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
        encoder_outputs = self.transformer(x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask)

        last_hidden = encoder_outputs[0]
        return (last_hidden,) + encoder_outputs[1:]  # last-layer hidden-state, (all hidden_states), (all attentions)
VictorSanh's avatar
wip  
VictorSanh committed
488
489


490
491
492
493
494
@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top. """,
    DISTILBERT_START_DOCSTRING,
    DISTILBERT_INPUTS_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]

    """

    def __init__(self, config):
        super(DistilBertForMaskedLM, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.distilbert = DistilBertModel(config)
        # MLM head: dense -> activation -> LayerNorm -> projection onto the vocabulary.
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()

        # -100 is the documented "ignore this token" label (also CrossEntropyLoss's default).
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def get_output_embeddings(self):
        """Return the vocabulary projection layer (the output embeddings of the MLM head)."""
        return self.vocab_projector

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, masked_lm_labels=None):
        """Run DistilBERT followed by the MLM head; see the class docstring for arguments and outputs."""
        dlbrt_output = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
        )
        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        # Forward whatever optional hidden states / attentions the base model returned.
        outputs = (prediction_logits,) + dlbrt_output[1:]
        if masked_lm_labels is not None:
            # Flatten batch and sequence dims so every token is one classification example.
            mlm_loss = self.mlm_loss_fct(
                prediction_logits.view(-1, prediction_logits.size(-1)), masked_lm_labels.view(-1)
            )
            outputs = (mlm_loss,) + outputs

        return outputs  # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
561

VictorSanh's avatar
VictorSanh committed
562

563
564
@add_start_docstrings(
    """DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
                         the pooled output) e.g. for GLUE tasks. """,
    DISTILBERT_START_DOCSTRING,
    DISTILBERT_INPUTS_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` (``torch.FloatTensor`` for regression) of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1])  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """

    def __init__(self, config):
        super(DistilBertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        """Classify (or regress) whole sequences; see the class docstring for arguments and outputs."""
        distilbert_output = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        # "Pool" by taking the hidden state of the first token of every sequence.
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.functional.relu(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        # Forward whatever optional hidden states / attentions the base model returned.
        outputs = (logits,) + distilbert_output[1:]
        if labels is not None:
            if self.num_labels == 1:
                # Regression: MSELoss needs a floating-point target with the logits' dtype,
                # so integer (LongTensor) labels are cast instead of raising a dtype error.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

635

636
637
@add_start_docstrings(
    """DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
                         the hidden-states output to compute `span start logits` and `span end logits`). """,
    DISTILBERT_START_DOCSTRING,
    DISTILBERT_INPUTS_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
            The tensor passed in is not modified.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
            The tensor passed in is not modified.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the average of the Cross-Entropy losses for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """

    def __init__(self, config):
        super(DistilBertForQuestionAnswering, self).__init__(config)

        self.distilbert = DistilBertModel(config)
        # One output per token for the span start and one for the span end.
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        assert config.num_labels == 2
        self.dropout = nn.Dropout(config.qa_dropout)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
    ):
        """Predict answer-span start/end logits; see the class docstring for arguments and outputs."""
        distilbert_output = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)

        # Forward whatever optional hidden states / attentions the base model returned.
        outputs = (start_logits, end_logits,) + distilbert_output[1:]
        if start_positions is not None and end_positions is not None:
            # On multi-GPU the positions may arrive with an extra trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions beyond the model inputs are clamped to `ignored_index` (the sequence length),
            # which the loss below is told to ignore. Clamp out-of-place so the caller's tensors are
            # left untouched.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
729
730


731
732
@add_start_docstrings(
    """DistilBert Model with a token classification head on top (a linear layer on top of
                      the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    DISTILBERT_START_DOCSTRING,
    DISTILBERT_INPUTS_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """

    def __init__(self, config):
        super(DistilBertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        # Same width as the sibling heads, which size their layers with `config.dim`.
        self.classifier = nn.Linear(config.dim, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None):
        """Classify every token; see the class docstring for arguments and outputs."""
        distilbert_output = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds
        )
        sequence_output = distilbert_output[0]  # (bs, seq_len, dim)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (bs, seq_len, num_labels)

        # DistilBertModel returns (hidden_state, (hidden_states), (attentions)) with no pooled
        # output, so the optional extras start at index 1 (unlike BERT, where they start at 2).
        outputs = (logits,) + distilbert_output[1:]
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            if attention_mask is not None:
                # Only tokens with attention_mask == 1 contribute to the loss (padding is skipped).
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)