modeling_dilbert.py 28 KB
Newer Older
VictorSanh's avatar
wip  
VictorSanh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch DilBERT model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
VictorSanh's avatar
VictorSanh committed
23
import copy
VictorSanh's avatar
wip  
VictorSanh committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
from io import open

import itertools
import numpy as np

import torch
import torch.nn as nn

from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings

import logging
logger = logging.getLogger(__name__)


# Shortcut name -> pretrained weights archive (assigned to
# `DilBertPreTrainedModel.pretrained_model_archive_map` further down).
# The single entry is still a `None` placeholder, see the TODO.
DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'dilbert-base-uncased': None, # TODO(Victor)
}

# Shortcut name -> pretrained configuration file (assigned to
# `DilBertconfig.pretrained_config_archive_map`).
# The single entry is still a `None` placeholder, see the TODO.
DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'dilbert-base-uncased': None, #TODO(Victor)
}


class DilBertconfig(PretrainedConfig):
    """
    Configuration of a DilBERT model.

    Parameters
    ----------
    vocab_size_or_config_json_file: int or str
        Either the vocabulary size (int) or the path to a JSON configuration file (str).
        When a path is given, every key of the JSON file is copied onto the instance
        and the other arguments below are ignored.
    max_position_embeddings: int
        Number of rows of the position embedding table.
    sinusoidal_pos_embds: bool
        Whether the position embeddings are filled with sinusoidal values.
    n_layers: int
        Number of transformer blocks.
    n_heads: int
        Number of attention heads (must divide `dim`).
    dim: int
        Hidden size of the embeddings and of the transformer blocks.
    hidden_dim: int
        Inner size of the feed-forward network.
    dropout: float
        Dropout probability used in the embeddings and in the feed-forward network.
    attention_dropout: float
        Dropout probability applied to the attention weights.
    activation: str
        Activation of the feed-forward network ('relu' or 'gelu').
    initializer_range: float
        Standard deviation of the normal initialisation of the weights.
    tie_weights: bool
        Whether the masked LM head shares its projection with the word embeddings.
    qa_dropout: float
        Dropout probability read by the question answering head.
    seq_classif_dropout: float
        Dropout probability read by the sequence classification head.
    """
    pretrained_config_archive_map = DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 max_position_embeddings=512,
                 sinusoidal_pos_embds=True,
                 n_layers=6,
                 n_heads=12,
                 dim=768,
                 hidden_dim=4*768,
                 dropout=0.1,
                 attention_dropout=0.1,
                 activation='gelu',
                 initializer_range=0.02,
                 tie_weights=True,
                 qa_dropout=0.1,
                 seq_classif_dropout=0.2,
                 **kwargs):
        super(DilBertconfig, self).__init__(**kwargs)

        # `unicode` only exists on Python 2; the version test short-circuits on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.max_position_embeddings = max_position_embeddings
            self.sinusoidal_pos_embds = sinusoidal_pos_embds
            self.n_layers = n_layers
            self.n_heads = n_heads
            self.dim = dim
            # Was accepted but never stored, although FFN and TransformerBlock read it.
            self.hidden_dim = hidden_dim
            self.dropout = dropout
            self.attention_dropout = attention_dropout
            self.activation = activation
            self.initializer_range = initializer_range
            self.tie_weights = tie_weights
            # Read by the task-specific heads defined in this file.
            self.qa_dropout = qa_dropout
            self.seq_classif_dropout = seq_classif_dropout
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")


VictorSanh's avatar
VictorSanh committed
90
### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
VictorSanh's avatar
wip  
VictorSanh committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def gelu(x):
    """
    Gaussian Error Linear Unit, exact (erf based) formulation.

    Parameters
    ----------
    x: torch.tensor
        Input tensor of any shape.

    Outputs
    -------
    torch.tensor with the same shape as `x`, i.e. `x * Phi(x)` where `Phi` is the
    cumulative distribution function of the standard normal distribution.
    """
    half_x = 0.5 * x
    one_plus_erf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * one_plus_erf

def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill `out` in place with fixed sinusoidal position encodings and freeze it.

    Parameters
    ----------
    n_pos: int
        Number of positions (rows of `out`).
    dim: int
        Encoding size (columns of `out`).
    out: torch.tensor(n_pos, dim)
        Tensor to overwrite, typically the weight of a position `nn.Embedding`.
        Even columns receive the sine, odd columns the cosine.
        On return it is detached and has `requires_grad=False`.
    """
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        for pos in range(n_pos)
    ])
    # Autograd rejects in-place writes into a leaf tensor that requires grad
    # (which is what an embedding weight is), so the writes happen without grad tracking.
    with torch.no_grad():
        out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    out.requires_grad = False

class Embeddings(nn.Module):
    """
    Word embeddings plus absolute position embeddings, followed by LayerNorm and dropout.
    There are no token_type embeddings.
    """

    def __init__(self, config):
        super(Embeddings, self).__init__()
        hidden_size = config.dim
        n_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(n_positions, hidden_size)
        if config.sinusoidal_pos_embds:
            # Overwrites the position table in place with the fixed sinusoidal values.
            create_sinusoidal_embeddings(n_pos=n_positions,
                                         dim=hidden_size,
                                         out=self.position_embeddings.weight)

        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, max_seq_length)
            The token ids to embed.

        Outputs
        -------
        embeddings: torch.tensor(bs, max_seq_length, dim)
            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        n_tokens = input_ids.size(1)
        positions = torch.arange(n_tokens, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
        positions = positions[None, :].expand_as(input_ids)                            # (bs, max_seq_length)

        summed = self.word_embeddings(input_ids) + self.position_embeddings(positions)  # (bs, max_seq_length, dim)
        normalized = self.LayerNorm(summed)                                             # (bs, max_seq_length, dim)
        return self.dropout(normalized)                                                 # (bs, max_seq_length, dim)

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with linear projections for the
    queries, keys, values and for the output.
    """

    def __init__(self,
                 config):
        super(MultiHeadSelfAttention, self).__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: torch.Tensor):
        """
        Parameters
        ----------
        query: torch.tensor(bs, q_length, dim)
        key: torch.tensor(bs, k_length, dim)
        value: torch.tensor(bs, k_length, dim)
        mask: torch.tensor(bs, k_length) or torch.tensor(bs, q_length, k_length)
            Positions equal to 0 are masked out (they receive no attention).
            A 2-D mask is shared by all query positions, a 3-D mask is given per query position.

        Outputs
        -------
        context: torch.tensor(bs, q_length, dim)
            Contextualized layer.
        weights: torch.tensor(bs, n_heads, q_length, k_length)
            Attention weights. Optional: only if `output_attentions=True`
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        assert key.size() == value.size()

        dim_per_head = dim // self.n_heads

        assert 2 <= mask.dim() <= 3
        # A 3-D mask carries one row per query position, a 2-D mask is broadcast over the queries.
        if mask.dim() == 3:
            mask_reshp = (bs, 1, q_length, k_length)
        else:
            mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, dim)

        q = shape(self.q_lin(query))           # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))             # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))           # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)                     # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))         # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
        # Masked positions get -inf so that the softmax gives them a weight of exactly 0.
        scores.masked_fill_(mask, -float('inf'))            # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)                  # (bs, n_heads, q_length, k_length)
        context = torch.matmul(weights, v)               # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)                       # (bs, q_length, dim)
        context = self.out_lin(context)                  # (bs, q_length, dim)

        if self.output_attentions:
            return (context, weights)
        else:
            return (context,)
VictorSanh's avatar
wip  
VictorSanh committed
217
218
219
220
221
222
223
224

class FFN(nn.Module):
    """
    Position-wise feed-forward network: linear -> activation -> linear -> dropout.
    """

    def __init__(self,
                 config):
        super(FFN, self).__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        # Explicit raise: an `assert` is stripped under `python -O`, which would let an
        # unknown activation silently fall back to ReLU.
        if config.activation not in ['relu', 'gelu']:
            raise ValueError(f"activation ({config.activation}) must be in ['relu', 'gelu']")
        self.activation = gelu if config.activation == 'gelu' else nn.ReLU()

    def forward(self,
                input: torch.Tensor):
        """
        Parameters
        ----------
        input: torch.tensor(..., dim)
            Hidden states; the network is applied on the last dimension.

        Outputs
        -------
        x: torch.tensor(..., dim)
            Transformed hidden states, same shape as `input`.
        """
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    """
    One transformer layer: self-attention and feed-forward network, each followed
    by a residual connection and a LayerNorm.
    """

    def __init__(self,
                 config):
        super(TransformerBlock, self).__init__()

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.dropout = nn.Dropout(p=config.dropout)
        self.activation = config.activation
        self.output_attentions = config.output_attentions

        assert config.dim % config.n_heads == 0

        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self,
                x: torch.Tensor,
                attn_mask: torch.Tensor = None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
        attn_mask: torch.tensor(bs, seq_length)
            Optional: if None, every position is attended to.

        Outputs
        -------
        If `output_attentions=False`, only `ffn_output` is returned (a tensor, not a tuple).
        If `output_attentions=True`, the tuple `(sa_weights, ffn_output)` is returned.

        sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            The attention weights
        ffn_output: torch.tensor(bs, seq_length, dim)
            The output of the transformer block contextualization.
        """
        if attn_mask is None:
            # The attention module needs an actual mask tensor: attend everywhere.
            attn_mask = torch.ones(x.size()[:2], dtype=torch.long, device=x.device)  # (bs, seq_length)

        # Self-Attention
        sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
        # The attention module always returns a tuple: (context,) or (context, weights).
        if self.output_attentions:
            sa_output, sa_weights = sa_output                  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:
            sa_output = sa_output[0]                           # (bs, seq_length, dim)
        sa_output = self.sa_layer_norm(sa_output + x)          # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)                             # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        if self.output_attentions:
            return (sa_weights, ffn_output)
        return ffn_output
VictorSanh's avatar
wip  
VictorSanh committed
286
287
288
289
290
291
292

class Transformer(nn.Module):
    """
    Stack of `config.n_layers` transformer blocks applied one after the other.
    """

    def __init__(self, config):
        super(Transformer, self).__init__()
        self.n_layers = config.n_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        # Build one block and clone it, so that all layers start from identical weights.
        prototype = TransformerBlock(config)
        blocks = [copy.deepcopy(prototype) for _ in range(config.n_layers)]
        self.layer = nn.ModuleList(blocks)

    def forward(self,
                x: torch.tensor,
                attn_mask: torch.tensor = None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
            Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequence.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        states_per_layer = []
        attentions_per_layer = []

        hidden_state = x
        for block in self.layer:
            block_output = block(x=hidden_state, attn_mask=attn_mask)
            if self.output_attentions:
                # The block returns (attention weights, hidden state) in that case.
                layer_attentions, hidden_state = block_output
                attentions_per_layer.append(layer_attentions)
            else:
                hidden_state = block_output
            states_per_layer.append(hidden_state)

        outputs = (hidden_state,)
        if self.output_hidden_states:
            outputs += (tuple(states_per_layer),)
        if self.output_attentions:
            outputs += (tuple(attentions_per_layer),)
        return outputs


### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
class DilBertPreTrainedModel(PreTrainedModel):
    """ Abstract base class of the DilBERT models: it provides the weights initialization
        and the class attributes used for downloading and loading pretrained models.
    """
    config_class = DilBertconfig
    pretrained_model_archive_map = DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "dilbert"

    def __init__(self, *inputs, **kwargs):
        super(DilBertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of a single sub-module (meant to be used with `self.apply`).

            - Embeddings: normal(0, initializer_range), skipped when the weight is frozen.
            - Linear layers: normal(0, initializer_range) weight, zero bias.
            - LayerNorm: unit weight, zero bias.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        # Frozen embedding tables (requires_grad=False) are left untouched.
        if isinstance(module, nn.Embedding) and module.weight.requires_grad:
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)


# Shared introduction passed to `add_start_docstrings` for every DilBERT model class below.
DILBERT_START_DOCSTRING = r"""
    Smaller, faster, cheaper, lighter: DilBERT

    For more information on DilBERT, you should check TODO(Victor): Link to Medium

    Parameters:
        config (:class:`~pytorch_transformers.DilBertconfig`): Model configuration class with all the parameters of the model. 
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared description of the model inputs passed to `add_start_docstrings`
# for every DilBERT model class below.
DILBERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The input sequences should start with a `[CLS]` token and end with a `[SEP]` token.
            
            For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DilBERT.
        **attention_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""

@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertModel(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertModel, self).__init__(config)

        # Embeddings feed the encoder (stack of transformer blocks).
        self.embeddings = Embeddings(config)
        self.transformer = Transformer(config)

        self.apply(self.init_weights)

    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Sequences of token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequences. Optional: If None, every token is attended to
            (as if there was no padding).

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        pooled_output: torch.tensor(bs, dim)
            Hidden state of the first token of each sequence (the [CLS] position).
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask  # (bs, seq_length)

        embedded = self.embeddings(input_ids)                        # (bs, seq_length, dim)
        encoded = self.transformer(x=embedded, attn_mask=mask)
        last_hidden = encoded[0]                                     # (bs, seq_length, dim)
        first_token_state = last_hidden[:, 0]                        # (bs, dim)

        # hidden_state, pooled_output, (hidden_states), (attentions)
        return (last_hidden, first_token_state) + encoded[1:]
VictorSanh's avatar
wip  
VictorSanh committed
437

VictorSanh's avatar
VictorSanh committed
438
439
440
441
442
443
444
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForMaskedLM(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForMaskedLM, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.encoder = DilBertModel(config)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.apply(self.init_weights)
        # The tying method of this class is `tie_weights_` (trailing underscore);
        # calling `self.tie_weights()` never reached it.
        self.tie_weights_()

        # Positions labelled -1 do not contribute to the loss.
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    def tie_weights_(self):
        """
        Tying the weights of the vocabulary projection to the base token embeddings.
        Only done when `config.tie_weights` is true.
        """
        if self.config.tie_weights:
            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor = None,
                masked_lm_labels: torch.Tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        masked_lm_labels: torch.tensor(bs, seq_length)
            The masked language modeling labels. Optional: If None, no loss is computed.
            Positions with label -1 are ignored by the loss.

        Outputs
        -------
        mlm_loss: torch.tensor(1,)
            Masked Language Modeling loss to optimize.
            Optional: only if `masked_lm_labels` is not None
        prediction_logits: torch.tensor(bs, seq_length, voc_size)
            Token prediction logits
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
        tfmr_output = self.encoder(input_ids=input_ids,
                                   attention_mask=attention_mask)
        hidden_states = tfmr_output[0]                               # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)      # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)                  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        # Index 1 of the encoder output is the pooled output, which this head does not expose.
        outputs = (prediction_logits, ) + tfmr_output[2:]
        if masked_lm_labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                         masked_lm_labels.view(-1))
            outputs = (mlm_loss,) + outputs

        return outputs # (mlm_loss), prediction_logits, (hidden_states), (attentions)

@add_start_docstrings("""DilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
                         the pooled output) e.g. for GLUE tasks. """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForSequenceClassification(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.dilbert = DilBertModel(config)
        # Head: linear -> ReLU -> dropout -> linear
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.apply(self.init_weights)

    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor = None,
                labels: torch.tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        labels: torch.tensor(bs,)
            Classification labels (regression targets if `num_labels` == 1).
            Optional: If None, no loss will be computed.

        Outputs
        -------
        loss: torch.tensor(1)
            Sequence classification loss: mean squared error if `num_labels` == 1,
            cross entropy otherwise.
            Optional: Is computed only if `labels` is not None.
        logits: torch.tensor(bs, num_labels)
            Classification (or regression if config.num_labels==1) scores
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
        encoder_output = self.dilbert(input_ids=input_ids,
                                      attention_mask=attention_mask)
        features = self.pre_classifier(encoder_output[1])    # (bs, dim), from the pooled output
        features = self.dropout(nn.ReLU()(features))         # (bs, dim)
        logits = self.classifier(features)                   # (bs, num_labels)

        # Index 0 and 1 of the encoder output (hidden state, pooled output) are not exposed.
        outputs = (logits,) + encoder_output[2:]
        if labels is None:
            return outputs  # logits, (hidden_states), (attentions)

        if self.num_labels == 1:
            # Single output: treated as a regression.
            loss = nn.MSELoss()(logits.view(-1), labels.view(-1))
        else:
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss,) + outputs  # loss, logits, (hidden_states), (attentions)

@add_start_docstrings("""DilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
                         the hidden-states output to compute `span start logits` and `span end logits`). """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForQuestionAnswering(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForQuestionAnswering, self).__init__(config)

        # One logit for the span start and one for the span end; checked before building the layer.
        assert config.num_labels == 2
        self.dilbert = DilBertModel(config)
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.qa_dropout)

        self.apply(self.init_weights)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor = None,
                start_positions: torch.Tensor = None,
                end_positions: torch.Tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        start_positions: torch.tensor(bs)
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
            Optional: if None, no loss is computed.
        end_positions: torch.tensor(bs)
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
            Optional: if None, no loss is computed.

        Outputs
        -------
        loss: torch.tensor(1)
            Question answering loss.
            Optional: Is computed only if `start_positions` and `end_positions` are not None.
        start_logits: torch.tensor(bs, seq_length)
            Span-start scores.
        end_logits: torch.tensor(bs, seq_length)
            Span-end scores.
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
        dilbert_output = self.dilbert(input_ids=input_ids,
                                      attention_mask=attention_mask)
        hidden_states = dilbert_output[0]                                 # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)                       # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)                           # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)                           # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)                               # (bs, max_query_len)

        # Index 1 of the encoder output is the pooled output, which this head does not expose.
        outputs = (start_logits, end_logits,) + dilbert_output[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms:
            # they are clamped to `ignored_index`, which the loss skips.
            # Out-of-place clamp so that the caller's label tensors are left unmodified.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)