# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
 PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
 part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""


import copy
import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import gelu
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_distilbert import DistilBertConfig


logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"

DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "distilbert-base-uncased",
    "distilbert-base-uncased-distilled-squad",
    "distilbert-base-cased",
    "distilbert-base-cased-distilled-squad",
    "distilbert-base-german-cased",
    "distilbert-base-multilingual-cased",
    "distilbert-base-uncased-finetuned-sst-2-english",
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
]


# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #


def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
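
# A commented sketch of how the helper above is typically wired in (it mirrors the call made in
# `Embeddings.__init__` below; `n_pos=4, dim=6` are made-up illustrative sizes, not config values):
#
#   pos_emb = nn.Embedding(4, 6)
#   create_sinusoidal_embeddings(n_pos=4, dim=6, out=pos_emb.weight)
#   # even columns now hold sin(pos / 10000^(2i/dim)), odd columns the matching cos terms,
#   # and the table is frozen (requires_grad=False), so it is never updated during training.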


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters:
            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.

        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
        embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings
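
# Commented shape-check sketch for the module above, assuming only the `DistilBertConfig` defaults
# (no pretrained checkpoint involved):
#
#   config = DistilBertConfig()
#   embeddings = Embeddings(config)
#   input_ids = torch.randint(0, config.vocab_size, (2, 10))   # (bs, max_seq_length)
#   assert embeddings(input_ids).shape == (2, 10, config.dim)  # (bs, max_seq_length, dim)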


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights. Optional: only if
                `output_attentions=True`
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """group heads"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
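
# Commented sketch of the head split/merge done by `shape`/`unshape` above, on a plain tensor
# (bs=2, seq_length=5, n_heads=4, dim_per_head=16 are made-up numbers):
#
#   x = torch.randn(2, 5, 64)
#   split = x.view(2, 5, 4, 16).transpose(1, 2)                 # (bs, n_heads, seq_length, dim_per_head)
#   merged = split.transpose(1, 2).contiguous().view(2, 5, 64)  # back to (bs, seq_length, dim)
#   assert torch.equal(x, merged)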


class FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        assert config.activation in ["relu", "gelu"], f"activation ({config.activation}) must be in ['relu', 'gelu']"
        self.activation = gelu if config.activation == "gelu" else nn.ReLU()

    def forward(self, input):
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input):
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x
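
# Note: `apply_chunking_to_forward` above splits the sequence dimension into chunks of
# `config.chunk_size_feed_forward` and concatenates the results; with the default chunk size of 0
# it reduces to a direct call to `ff_chunk`. Commented sketch (default config, illustrative sizes):
#
#   config = DistilBertConfig()
#   ffn = FFN(config)
#   assert ffn(torch.randn(2, 5, config.dim)).shape == (2, 5, config.dim)  # (bs, seq_length, dim)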


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights
            ffn_output: torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            assert type(sa_output) == tuple
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output


class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(
        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
    ):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )


# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


DISTILBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = Embeddings(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
        return self.transformer(
            x=inputs_embeds,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
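
# Commented usage sketch for the bare encoder above (same doctest style as the example in
# `DistilBertForMultipleChoice`; uses the `distilbert-base-uncased` checkpoint from _CHECKPOINT_FOR_DOC):
#
#   >>> from transformers import DistilBertTokenizer, DistilBertModel
#   >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#   >>> model = DistilBertModel.from_pretrained("distilbert-base-uncased")
#   >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#   >>> outputs = model(**inputs)
#   >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, dim)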


@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top. """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss()

    def get_output_embeddings(self):
        return self.vocab_projector

    def set_output_embeddings(self, new_embeddings):
        self.vocab_projector = new_embeddings

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        mlm_loss = None
        if labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (prediction_logits,) + dlbrt_output[1:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_logits,
            hidden_states=dlbrt_output.hidden_states,
            attentions=dlbrt_output.attentions,
        )
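
# Commented usage sketch for the masked-LM head above (doctest style; checkpoint as in
# _CHECKPOINT_FOR_DOC):
#
#   >>> from transformers import DistilBertTokenizer
#   >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#   >>> model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
#   >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#   >>> logits = model(**inputs).logits  # (1, seq_length, vocab_size)
#   >>> mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
#   >>> tokenizer.decode([logits[0, mask_index].argmax(-1).item()])  # expected to be close to "paris"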


@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss).
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
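
# Commented sketch of how `config.problem_type` above picks the loss (randomly initialized model,
# illustrative config values, no checkpoint):
#
#   config = DistilBertConfig(num_labels=3, problem_type="single_label_classification")
#   model = DistilBertForSequenceClassification(config)
#   input_ids = torch.randint(0, config.vocab_size, (2, 8))
#   out = model(input_ids=input_ids, labels=torch.tensor([0, 2]))  # long labels -> CrossEntropyLoss branch
#   out.loss.backward()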


@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        assert config.num_labels == 2
        self.dropout = nn.Dropout(config.qa_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds an extra dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
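
# Commented usage sketch for the QA head above (doctest style; the SQuAD-distilled checkpoint named
# here is listed in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST):
#
#   >>> from transformers import DistilBertTokenizer
#   >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
#   >>> model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
#   >>> inputs = tokenizer("Who develops DistilBERT?", "DistilBERT is developed by Hugging Face.", return_tensors="pt")
#   >>> outputs = model(**inputs)
#   >>> start = outputs.start_logits.argmax(-1).item()
#   >>> end = outputs.end_logits.argmax(-1).item()
#   >>> tokenizer.decode(inputs.input_ids[0, start : end + 1])  # decoded answer span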


@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, 1)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)

        Returns:

        Examples::

            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
            >>> import torch

            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')

            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            >>> choice0 = "It is eaten with a fork and a knife."
            >>> choice1 = "It is eaten while held in the hand."
            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1
            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)
            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
            >>> # the linear classifier still needs to be trained
            >>> loss = outputs.loss
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)

        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )