# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
 PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
 part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""


import copy
import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from ...activations import gelu
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_distilbert import DistilBertConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"

DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "distilbert-base-uncased",
    "distilbert-base-uncased-distilled-squad",
    "distilbert-base-cased",
    "distilbert-base-cased-distilled-squad",
    "distilbert-base-german-cased",
    "distilbert-base-multilingual-cased",
    "distilbert-base-uncased-finetuned-sst-2-english",
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
]


# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #


def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
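
# Illustrative sketch (not executed anywhere in this module): how the helper above
# would fill a small table. The tiny n_pos/dim values are arbitrary, chosen only to
# show the shapes; even columns receive sin(pos / 10000^(2i/dim)), odd columns cos.
#
#     weight = torch.empty(4, 6)                       # n_pos=4 positions, dim=6
#     create_sinusoidal_embeddings(n_pos=4, dim=6, out=weight)
#     # `weight` is detached afterwards, so the positional table is never trained.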


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters:
            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.

        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
        embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings
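
# Shape sketch (illustrative only, assuming a default DistilBertConfig with dim=768):
#
#     config = DistilBertConfig()
#     embeddings = Embeddings(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 10))   # (bs, max_seq_length)
#     embeddings(input_ids).shape                                 # torch.Size([2, 10, 768])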


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length)
                Attention weights. Optional: only if `output_attentions=True`
            context: torch.tensor(bs, seq_length, dim)
                Contextualized layer.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
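
# Shape sketch (illustrative only): a self-attention call with no padding and the
# default 12 heads over dim=768 returns a 1-tuple unless `output_attentions=True`.
#
#     attn = MultiHeadSelfAttention(config)
#     x = torch.rand(2, 10, config.dim)                # (bs, seq_length, dim)
#     mask = torch.ones(2, 10)                         # 1 = attend, 0 = padding
#     (context,) = attn(query=x, key=x, value=x, mask=mask)        # (2, 10, 768)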


class FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = gelu if config.activation == "gelu" else nn.ReLU()

    def forward(self, input):
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input):
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x
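
# Shape sketch (illustrative only): the feed-forward sub-layer expands dim -> hidden_dim
# -> dim (768 -> 3072 -> 768 with the default config), optionally chunked along the
# sequence axis via `apply_chunking_to_forward`.
#
#     ffn = FFN(config)
#     ffn(torch.rand(2, 10, config.dim)).shape         # torch.Size([2, 10, 768])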


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
                The attention weights
            ffn_output: torch.tensor(bs, seq_length, dim)
                The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            assert type(sa_output) == tuple
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output
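
# Shape sketch (illustrative only): one block applies self-attention + residual +
# LayerNorm, then the FFN + residual + LayerNorm, keeping the (bs, seq_length, dim) shape.
#
#     block = TransformerBlock(config)
#     (hidden,) = block(torch.rand(2, 10, config.dim), attn_mask=torch.ones(2, 10))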


class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(
        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
    ):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )
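
# Usage sketch (illustrative only): running the stack directly on already-embedded
# inputs; `head_mask` must be indexable per layer, so a list of None values is enough.
#
#     encoder = Transformer(config)                    # n_layers=6 by default
#     x = torch.rand(2, 10, config.dim)
#     out = encoder(
#         x, attn_mask=torch.ones(2, 10), head_mask=[None] * config.n_layers, return_dict=True
#     )
#     out.last_hidden_state.shape                      # torch.Size([2, 10, 768])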


# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Embedding):
            if module.weight.requires_grad:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


DISTILBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = Embeddings(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
        return self.transformer(
            x=inputs_embeds,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
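
# Usage sketch (illustrative only; assumes the `distilbert-base-uncased` checkpoint is
# available, e.g. downloaded from the Hugging Face hub):
#
#     from transformers import DistilBertTokenizer
#
#     tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#     model = DistilBertModel.from_pretrained("distilbert-base-uncased")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape                  # (1, seq_length, 768)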


@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top. """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss()

    def get_output_embeddings(self):
        return self.vocab_projector

    def set_output_embeddings(self, new_embeddings):
        self.vocab_projector = new_embeddings

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        mlm_loss = None
        if labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (prediction_logits,) + dlbrt_output[1:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_logits,
            hidden_states=dlbrt_output.hidden_states,
            attentions=dlbrt_output.attentions,
        )
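
# Usage sketch (illustrative only, same checkpoint assumption as above): passing the
# input ids as `labels` gives a loss over the whole sequence; in practice only masked
# positions are usually labelled and the rest set to -100.
#
#     model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     outputs = model(**inputs, labels=inputs["input_ids"])
#     outputs.loss, outputs.logits.shape               # scalar, (1, seq_length, vocab_size)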


@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
611
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
613
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
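
# Usage sketch (illustrative only; the SST-2 fine-tuned checkpoint is one of the
# checkpoints listed in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST above):
#
#     name = "distilbert-base-uncased-finetuned-sst-2-english"
#     tokenizer = DistilBertTokenizer.from_pretrained(name)
#     model = DistilBertForSequenceClassification.from_pretrained(name)
#     inputs = tokenizer("I really enjoyed this movie!", return_tensors="pt")
#     predicted_class = model(**inputs).logits.argmax(dim=-1).item()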


@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        assert config.num_labels == 2
        self.dropout = nn.Dropout(config.qa_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
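
# Usage sketch (illustrative only, assuming the SQuAD-distilled checkpoint): the answer
# span is recovered from the argmax of the start and end logits.
#
#     name = "distilbert-base-uncased-distilled-squad"
#     tokenizer = DistilBertTokenizer.from_pretrained(name)
#     model = DistilBertForQuestionAnswering.from_pretrained(name)
#     inputs = tokenizer("Who created DistilBERT?", "DistilBERT was created by Hugging Face.", return_tensors="pt")
#     out = model(**inputs)
#     start, end = out.start_logits.argmax(), out.end_logits.argmax()
#     tokenizer.decode(inputs["input_ids"][0, start : end + 1])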


@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
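
# Usage sketch (illustrative only): token-level logits come back with one row per input
# token, so predictions align with the tokenized sequence, not the original words.
#
#     model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased")
#     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
#     logits = model(**inputs).logits                  # (1, seq_length, num_labels)
#     predictions = logits.argmax(dim=-1)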


@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, 1)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)

        Returns:

        Examples::

            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
            >>> import torch

            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')

            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            >>> choice0 = "It is eaten with a fork and a knife."
            >>> choice1 = "It is eaten while held in the hand."
            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)
            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1

            >>> # the linear classifier still needs to be trained
            >>> loss = outputs.loss
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)

        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )