# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
 PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
 part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""


import copy
import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from ...activations import gelu
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_distilbert import DistilBertConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"

DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "distilbert-base-uncased",
    "distilbert-base-uncased-distilled-squad",
    "distilbert-base-cased",
    "distilbert-base-cased-distilled-squad",
    "distilbert-base-german-cased",
    "distilbert-base-multilingual-cased",
    "distilbert-base-uncased-finetuned-sst-2-english",
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
]


# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #


def create_sinusoidal_embeddings(n_pos, dim, out):
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()

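
# Illustrative sketch (not part of the library): `create_sinusoidal_embeddings` fills an
# existing weight matrix in place with the fixed sin/cos pattern, so a non-trainable
# position table can be built as below. The helper name `_demo_sinusoidal_table` and the
# default sizes are assumptions for this example only.
def _demo_sinusoidal_table(n_pos=512, dim=768):
    # Allocate a regular embedding table, then overwrite its weights with sin (even
    # indices) / cos (odd indices) values; gradients are disabled inside the helper.
    table = nn.Embedding(n_pos, dim)
    create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=table.weight)
    return table
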

class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters:
            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.

        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
        embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings

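
# Illustrative sketch (not part of the library): unlike BERT, DistilBERT embeddings are the
# sum of word and position embeddings only (no token_type/segment embeddings). The config
# values below restate the defaults and are assumptions for this example.
def _demo_embeddings_shapes():
    config = DistilBertConfig(vocab_size=30522, dim=768, max_position_embeddings=512, pad_token_id=0)
    embeddings = Embeddings(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 10))  # (bs, max_seq_length)
    return embeddings(input_ids).shape  # torch.Size([2, 10, 768])
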

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights. Optional: only if
                `output_attentions=True`
            context: torch.tensor(bs, seq_length, dim) Contextualized layer.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)

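
# Illustrative sketch (not part of the library): a shape walk-through of
# MultiHeadSelfAttention under an assumed default-sized configuration. With
# `output_attentions=True` the module returns both the contextualized output and the
# attention weights, as documented in `forward` above.
def _demo_self_attention_shapes():
    config = DistilBertConfig(dim=768, n_heads=12, attention_dropout=0.1)
    attention = MultiHeadSelfAttention(config)
    x = torch.rand(2, 10, config.dim)  # (bs, seq_length, dim)
    mask = torch.ones(2, 10)  # (bs, seq_length), 1 = attend to this position
    context, weights = attention(x, x, x, mask, output_attentions=True)
    return context.shape, weights.shape  # (2, 10, 768), (2, 12, 10, 10)
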

class FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = gelu if config.activation == "gelu" else nn.ReLU()

    def forward(self, input):
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input):
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x

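
# Illustrative sketch (not part of the library): `apply_chunking_to_forward` used above
# splits the sequence axis into chunks of `config.chunk_size_feed_forward` positions and
# runs `ff_chunk` on each chunk, trading speed for peak memory; the default chunk size of
# 0 processes the whole sequence at once. The sizes below are assumptions for the example,
# and the sequence length must be divisible by the chunk size.
def _demo_chunked_ffn():
    config = DistilBertConfig(dim=768, hidden_dim=3072)
    config.chunk_size_feed_forward = 4  # set before constructing FFN, which reads it in __init__
    ffn = FFN(config)
    x = torch.rand(2, 16, 768)
    return ffn(x).shape  # (2, 16, 768), identical to the unchunked result
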

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights. Optional: only if
                `output_attentions=True`
            ffn_output: torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            assert type(sa_output) == tuple
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output


class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_layers = config.n_layers

        layer = TransformerBlock(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(
        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
    ):  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = layer_module(
                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
            )
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )

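
# Illustrative sketch (not part of the library): running the bare Transformer stack on
# already-embedded inputs. `head_mask` must be indexable per layer, hence the list of
# None below; inside DistilBertModel this is produced by `get_head_mask`.
def _demo_transformer_stack():
    config = DistilBertConfig(n_layers=6, dim=768, n_heads=12)
    encoder = Transformer(config)
    x = torch.rand(2, 10, config.dim)  # (bs, seq_length, dim), e.g. the output of Embeddings
    attn_mask = torch.ones(2, 10)
    output = encoder(x, attn_mask=attn_mask, head_mask=[None] * config.n_layers, return_dict=True)
    return output.last_hidden_state.shape  # (2, 10, 768)
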

# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


DISTILBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = Embeddings(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
        return self.transformer(
            x=inputs_embeds,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

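
# Illustrative usage sketch (not part of the library): encoding a sentence with the bare
# DistilBertModel. The checkpoint name matches the one used in the docstring examples;
# running this requires the pretrained weights to be downloadable.
def _demo_distilbert_model():
    from transformers import DistilBertTokenizer

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(**inputs, return_dict=True)
    return outputs.last_hidden_state.shape  # (1, seq_length, 768)
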

@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top. """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.init_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss()

    def get_output_embeddings(self):
        return self.vocab_projector

    def set_output_embeddings(self, new_embeddings):
        self.vocab_projector = new_embeddings

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        dlbrt_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        mlm_loss = None
        if labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (prediction_logits,) + dlbrt_output[1:]
            return ((mlm_loss,) + output) if mlm_loss is not None else output

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_logits,
            hidden_states=dlbrt_output.hidden_states,
            attentions=dlbrt_output.attentions,
        )

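
# Illustrative sketch (not part of the library): building `labels` for the masked-LM loss.
# As the docstring above states, positions set to -100 are ignored by the cross-entropy,
# so only the masked positions contribute to the loss. `_demo_masked_lm_labels` is a
# hypothetical helper, not a library function.
def _demo_masked_lm_labels(input_ids, tokenizer):
    labels = input_ids.clone()
    labels[input_ids != tokenizer.mask_token_id] = -100  # score only the [MASK] positions
    return labels
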

@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

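
# Illustrative sketch (not part of the library): DistilBERT has no pooler layer, so the
# sequence-classification head above pools by taking the hidden state of the first token
# ([CLS]) before pre_classifier -> ReLU -> dropout -> classifier.
def _demo_cls_pooling(last_hidden_state):
    # last_hidden_state: (bs, seq_length, dim) as returned by DistilBertModel
    return last_hidden_state[:, 0]  # (bs, dim)
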

@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        assert config.num_labels == 2
        self.dropout = nn.Dropout(config.qa_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

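
# Illustrative sketch (not part of the library): turning the start/end logits returned
# above into an answer string with a greedy argmax. A real pipeline would also mask out
# positions belonging to the question or to padding; that step is omitted here.
def _demo_extract_span(start_logits, end_logits, input_ids, tokenizer):
    start_index = int(torch.argmax(start_logits, dim=-1)[0])
    end_index = int(torch.argmax(end_logits, dim=-1)[0])
    return tokenizer.decode(input_ids[0, start_index : end_index + 1])
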

@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="distilbert-base-uncased",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

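
# Illustrative sketch (not part of the library): the "active loss" computation used above,
# restated on its own. Padding positions (attention_mask == 0) have their labels replaced
# by the loss function's ignore_index, so they do not contribute to the loss.
def _demo_active_token_classification_loss(logits, labels, attention_mask, num_labels):
    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    return loss_fct(logits.view(-1, num_labels), active_labels)
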

@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, 1)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)

        Returns:

        Examples::

            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
            >>> import torch

            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')

            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            >>> choice0 = "It is eaten with a fork and a knife."
            >>> choice1 = "It is eaten while held in the hand."
            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)
            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1

            >>> # the linear classifier still needs to be trained
            >>> loss = outputs.loss
            >>> logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)

        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )