"vscode:/vscode.git/clone" did not exist on "72666130a7369e5171c5092f0dd1a13bafa52b8c"
# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import sys
from io import open

import itertools
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .modeling_utils import (PretrainedConfig, PreTrainedModel,
                             prune_linear_layer, SequenceSummary, SQuADHead)

logger = logging.getLogger(__name__)

XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
}
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
}


class XLMConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `XLMModel`.

    Args:
        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`,
            or the path to a JSON configuration file.
        n_special: Number of special tokens in the vocabulary.
        emb_dim: Dimensionality of the embeddings and hidden states.
        n_layers: Number of hidden layers in the Transformer encoder.
        n_heads: Number of attention heads for each attention layer in
            the Transformer encoder.
        dropout: The dropout probability for all fully connected
            layers in the embeddings and encoder.
        attention_dropout: The dropout ratio for the attention
            probabilities.
        gelu_activation: Whether to use GELU (True) or ReLU (False) in the
            feed-forward layers.
        sinusoidal_embeddings: Whether to use fixed sinusoidal position embeddings
            instead of learned ones.
        causal: Whether to apply a causal (unidirectional) attention mask.
        asm: Whether to use an adaptive log softmax projection layer instead of a
            plain linear output layer.
        n_langs: Number of languages handled by the model. Language embeddings are
            only added when n_langs > 1.
        max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        embed_init_std: The standard deviation of the normal initializer used for
            the embedding matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
        init_std: The standard deviation of the normal initializer used for all
            other weight matrices.
        bos_index, eos_index, pad_index, unk_index, mask_index: Indices of the
            special tokens in the vocabulary.
        is_encoder: Whether the model is used as an encoder. Only the encoder
            setting is currently implemented.
        num_labels: Number of labels for sequence classification heads.
        summary_type: str, "last", "first", "mean", or "attn". The method used by
            `SequenceSummary` to pool the hidden states into a single vector.
        start_n_top, end_n_top: Beam sizes used by the SQuAD head when predicting
            start and end positions.
    """
    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30145,
                 n_special=0,
                 emb_dim=2048,
                 n_layers=12,
                 n_heads=16,
                 dropout=0.1,
                 attention_dropout=0.1,
                 gelu_activation=True,
                 sinusoidal_embeddings=False,
                 causal=False,
                 asm=False,
                 n_langs=1,
                 max_position_embeddings=512,
                 embed_init_std=2048 ** -0.5,
                 layer_norm_eps=1e-12,
                 init_std=0.02,
                 bos_index=0,
                 eos_index=1,
                 pad_index=2,
                 unk_index=3,
                 mask_index=5,
                 is_encoder=True,

                 finetuning_task=None,
                 num_labels=2,
                 summary_type='first',
                 summary_use_proj=True,
                 summary_activation=None,
                 summary_proj_to_labels=True,
                 summary_first_dropout=0.1,
                 start_n_top=5,
                 end_n_top=5,
                 **kwargs):
        """Constructs XLMConfig.
        """
        super(XLMConfig, self).__init__(**kwargs)

        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_words = vocab_size_or_config_json_file
            self.n_special = n_special
            self.emb_dim = emb_dim
            self.n_layers = n_layers
            self.n_heads = n_heads
            self.dropout = dropout
            self.attention_dropout = attention_dropout
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.causal = causal
            self.asm = asm
            self.n_langs = n_langs
            self.layer_norm_eps = layer_norm_eps
            self.bos_index = bos_index
            self.eos_index = eos_index
            self.pad_index = pad_index
            self.unk_index = unk_index
            self.mask_index = mask_index
            self.is_encoder = is_encoder
            self.max_position_embeddings = max_position_embeddings
            self.embed_init_std = embed_init_std
            self.init_std = init_std
            self.finetuning_task = finetuning_task
            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_proj_to_labels = summary_proj_to_labels
            self.summary_first_dropout = summary_first_dropout
            self.start_n_top = start_n_top
            self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

    @property
    def total_tokens_embeddings(self):
        return self.n_words + self.n_special

    @property
    def hidden_size(self):
        return self.emb_dim

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers


def create_sinusoidal_embeddings(n_pos, dim, out):
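    """Fill `out` in place with fixed sinusoidal position encodings
    (sine on even dimensions, cosine on odd dimensions) and mark it as non-trainable."""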
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        for pos in range(n_pos)
    ])
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    out.requires_grad = False


def gelu(x):
    """
    GELU activation
    https://arxiv.org/abs/1606.08415
    https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
    https://github.com/huggingface/pytorch-transformers/blob/master/modeling.py
    """
    # return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate hidden states mask, and optionally an attention mask.
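    Illustrative example: for lengths = torch.tensor([3, 2]) and slen = 3, `mask` keeps
    positions [[1, 1, 1], [1, 1, 0]] with shape (bs, slen); with causal=True, `attn_mask` is
    instead a (bs, slen, slen) lower-triangular pattern that lets each position attend only
    to itself and to earlier positions.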
    """
    bs = lengths.size(0)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # sanity check
    assert mask.size() == (bs, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):

    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        super(MultiHeadAttention, self).__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.output_attentions = config.output_attentions
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

    def prune_heads(self, heads):
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads

    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))                                          # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(v))                                          # (bs, n_heads, qlen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)
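            # The cache maps each layer_id to its accumulated (key, value) tensors; the total
            # cached length is tracked separately under the 'slen' key by the calling model.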

        q = q / math.sqrt(dim_per_head)                                       # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)               # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))                              # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        outputs = (self.out_lin(context),)
        if self.output_attentions:
            outputs = outputs + (weights,)
        return outputs


class TransformerFFN(nn.Module):

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super(TransformerFFN, self).__init__()
        self.dropout = config.dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)
        self.act = gelu if config.gelu_activation else F.relu

    def forward(self, input):
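        """Position-wise feed-forward block: linear, GELU/ReLU activation, linear, then dropout."""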
        x = self.lin1(input)
        x = self.act(x)
        x = self.lin2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class XLMPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = XLMConfig
    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights. """
        if isinstance(module, nn.Embedding):
            if self.config is not None and self.config.embed_init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
        if isinstance(module, nn.Linear):
            if self.config is not None and self.config.init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.constant_(module.bias, 0.)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class XLMModel(XLMPreTrainedModel):
    """
    XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau

    Paper: https://arxiv.org/abs/1901.07291

    Original code: https://github.com/facebookresearch/XLM

    Args:
        `config`: an XLMConfig class instance with the configuration to build a new model.
            Attention weights and hidden states are returned as additional outputs when
            ``config.output_attentions`` and ``config.output_hidden_states`` are set, respectively.

    Example::

        config = XLMConfig(vocab_size_or_config_json_file=32000, emb_dim=1024,
            n_layers=12, n_heads=16)

        model = XLMModel(config=config)
    """

    ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output', 
                  'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', 
                  'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                  'asm_cutoffs', 'asm_div_value']

    def __init__(self, config):  #, dico, is_encoder, with_output):
        super(XLMModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        # encoder / decoder, output layer
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        # self.with_output = with_output
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index
        # self.dico = dico
        # self.id2lang = config.id2lang
        # self.lang2id = config.lang2id
        # assert len(self.dico) == self.n_words
        # assert len(self.id2lang) == len(self.lang2id) == self.n_langs

        # model parameters
        self.dim = config.emb_dim       # 2048 by default
        self.hidden_dim = self.dim * 4  # 8192 by default
        self.n_heads = config.n_heads   # 16 by default
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
        # if self.is_decoder:
        #     self.layer_norm15 = nn.ModuleList()
        #     self.encoder_attn = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            # if self.is_decoder:
            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        self.apply(self.init_weights)

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    def forward(self, input_ids, lengths=None, positions=None, langs=None,
                token_type_ids=None, attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None, 
        """
        Performs a model forward pass. **Can be invoked by calling the model instance directly, once it has been instantiated.**

        Parameters:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary.
            `lengths`: an optional ``torch.LongTensor`` of size ``bs``, containing the length of each sentence.
                When not provided, lengths are inferred from the padding index.
            `positions`: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions.
            `langs`: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs.
            `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length];
                it is used as ``langs`` and cannot be passed together with `langs`.
            `attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
                selected in [0, 1], 1 for real tokens and 0 for padding. It is the mask to use when the
                sentences in a batch have different lengths.
            `cache`: an optional dictionary holding, for each attention layer, the precomputed (key, value)
                tensors from previous decoding steps, together with the cached length under the ``'slen'`` key.
                It is updated in place and can be used to speed up sequential decoding.
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, used to nullify selected heads of the transformer: 1.0 keeps a head,
                0.0 masks it.


        Returns:
            A ``tuple`` whose first element is the sequence of hidden states at the last layer,
            a ``torch.FloatTensor`` of shape [batch_size, sequence_length, hidden_size].

            If ``config.output_hidden_states`` is True, the tuple also contains the hidden states
            at the output of the embeddings and of each layer; if ``config.output_attentions`` is True,
            it also contains the attention weights of each layer.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

            outputs = model(input_ids, attention_mask=input_mask)
            last_hidden_state = outputs[0]  # (batch_size, sequence_length, hidden_size)
        """
        if lengths is None:
            lengths = (input_ids != self.pad_index).sum(dim=1).long()
        # mask = input_ids != self.pad_index

        # check inputs
        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
        #     assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # positions
        if positions is None:
            positions = input_ids.new((slen,)).long()
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
            assert positions.size() == (bs, slen)  # (slen, bs)
            # positions = positions.transpose(0, 1)

        # langs
        assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids"
        if token_type_ids is not None:
            langs = token_type_ids
        if langs is not None:
            assert langs.size() == (bs, slen)  # (slen, bs)
            # langs = langs.transpose(0, 1)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layers
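        # For example, head_mask = torch.ones(n_layers, n_heads) with selected entries set to 0.0
        # would disable those heads in the corresponding layers.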

        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
            input_ids = input_ids[:, -_slen:]
            positions = positions[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(input_ids)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        if langs is not None:
            tensor = tensor + self.lang_embeddings(langs)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
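        # zero out the vectors at padded positions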
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            if self.output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
            attn = attn_outputs[0]
            if self.output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
            # if self.is_decoder and src_enc is not None:
            #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            #     attn = F.dropout(attn, p=self.dropout, training=self.training)
            #     tensor = tensor + attn
            #     tensor = self.layer_norm15[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if self.output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        # move back sequence length to dimension 0
        # tensor = tensor.transpose(0, 1)

        outputs = (tensor,)
        if self.output_hidden_states:
            outputs = outputs + (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (attentions,)
        return outputs  # last hidden state, (hidden_states), (attentions)


class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).
    """
    def __init__(self, config):
        super(XLMPredLayer, self).__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )
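            # Adaptive softmax (nn.AdaptiveLogSoftmaxWithLoss) factorizes the output projection using
            # the frequency-ordered clusters in `asm_cutoffs`, which saves compute and memory for
            # large vocabularies at a small cost in accuracy.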

    def forward(self, x, y=None):
        """ Compute the loss, and optionally the scores.
        """
        outputs = ()
        if self.asm is False:
            scores = self.proj(x).view(-1, self.n_words)
            outputs = (scores,) + outputs
            if y is not None:
                loss = F.cross_entropy(scores, y, reduction='mean')
                outputs = (loss,) + outputs
        else:
            scores = self.proj.log_prob(x)
            outputs = (scores,) + outputs
            if y is not None:
                _, loss = self.proj(x, y)
                outputs = (loss,) + outputs

        return outputs


class XLMWithLMHeadModel(XLMPreTrainedModel):
    """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau

    Paper: https://arxiv.org/abs/1901.07291

    Original code: https://github.com/facebookresearch/XLM

    Args:
        `config`: an XLMConfig class instance with the configuration to build a new model.
            Attention weights and hidden states are returned as additional outputs when
            ``config.output_attentions`` and ``config.output_hidden_states`` are set, respectively.

    Example::

        config = XLMConfig(vocab_size_or_config_json_file=32000, emb_dim=1024,
            n_layers=12, n_heads=16)

        model = XLMWithLMHeadModel(config=config)
    """
    def __init__(self, config):
        super(XLMWithLMHeadModel, self).__init__(config)
        self.torchscript = config.torchscript

        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        if self.torchscript:
            self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
        else:
            self.pred_layer.proj.weight = self.transformer.embeddings.weight

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary.
            `lengths`: an optional ``torch.LongTensor`` of size ``bs``, containing the length of each sentence.
            `positions`: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions.
            `langs`: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs.
            `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length];
                it is used as ``langs`` and cannot be passed together with `langs`.
            `attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
                selected in [0, 1], 1 for real tokens and 0 for padding.
            `cache`: an optional dictionary with precomputed (key, value) tensors per attention layer and the
                cached length under ``'slen'``, used to speed up sequential decoding; it is updated in place.
            `labels`: an optional ``torch.LongTensor`` with the target token indices used to compute the
                language-modeling loss (cross-entropy, or the adaptive log softmax loss when ``config.asm`` is True).
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, used to nullify selected heads of the transformer: 1.0 keeps a head,
                0.0 masks it.


        Returns:
            A ``tuple`` whose first element is the prediction scores of the language modeling head
            (for the default linear projection, a ``torch.FloatTensor`` of shape
            [batch_size * sequence_length, vocab_size]). If ``labels`` is provided, the language-modeling
            loss is prepended to the tuple. Hidden states and attention weights follow when the
            corresponding ``config`` flags are set.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            outputs = model(input_ids)
            lm_scores = outputs[0]
        """
        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
        outputs = self.pred_layer(output, labels)
        outputs = outputs + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        return outputs


class XLMForSequenceClassification(XLMPreTrainedModel):
    """XLM model with a sequence classification/regression head on top, from
    "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau.

    Args:
        `config`: an XLMConfig class instance with the configuration to build a new model.
            Attention weights and hidden states are returned as additional outputs when
            ``config.output_attentions`` and ``config.output_hidden_states`` are set, respectively.
            The pooling strategy of the classification head is controlled by ``config.summary_type``
            ("last", "first", "mean" or "attn").



    Example::

        config = XLMConfig(vocab_size_or_config_json_file=32000, emb_dim=1024,
            n_layers=12, n_heads=16)

        model = XLMForSequenceClassification(config=config)

    """
    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Args:
            input_ids: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary.
            lengths: an optional ``torch.LongTensor`` of size ``bs``, containing the length of each sentence.
            positions: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions.
            langs: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs.
            token_type_ids: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length];
                it is used as ``langs`` and cannot be passed together with `langs`.
            attention_mask: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
                selected in [0, 1], 1 for real tokens and 0 for padding.
            cache: an optional dictionary with precomputed (key, value) tensors per attention layer and the
                cached length under ``'slen'``, used to speed up sequential decoding; it is updated in place.
            labels: an optional ``torch.LongTensor`` of shape [batch_size] with the classification targets
                (or regression targets when ``num_labels == 1``).
            head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, used to nullify selected heads of the transformer: 1.0 keeps a head,
                0.0 masks it.


        Returns:
            A ``tuple`` whose first element is ``logits``, a ``torch.FloatTensor`` of shape
            [batch_size, num_labels]. If ``labels`` is provided, the classification loss (or the MSE
            regression loss when ``num_labels == 1``) is prepended to the tuple. Hidden states and
            attention weights follow when the corresponding ``config`` flags are set.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

            outputs = model(input_ids, attention_mask=input_mask)
            logits = outputs[0]
        """
        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]
        logits = self.sequence_summary(output)

        outputs = (logits,) + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs


class XLMForQuestionAnswering(XLMPreTrainedModel):
    """
    XLM model for Question Answering (span extraction).
    This module is composed of the XLM model with a linear layer on top of
    the sequence output that computes start_logits and end_logits

    Args:
        `config`: an XLMConfig class instance with the configuration to build a new model.
            Attention weights and hidden states are returned as additional outputs when
            ``config.output_attentions`` and ``config.output_hidden_states`` are set, respectively.



    Example::

        config = XLMConfig(vocab_size_or_config_json_file=32000, emb_dim=1024,
            n_layers=12, n_heads=16)

        model = XLMForQuestionAnswering(config)
    """
    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = SQuADHead(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
        """
        Performs a model forward pass. **Can be invoked by calling the model instance directly, once it has been instantiated.**

        Args:
            input_ids: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary.
            lengths: an optional ``torch.LongTensor`` of size ``bs``, containing the length of each sentence.
            positions: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions.
            langs: an optional ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs.
            token_type_ids: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length];
                it is used as ``langs`` and cannot be passed together with `langs`.
            attention_mask: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
                selected in [0, 1], 1 for real tokens and 0 for padding.
            cache: an optional dictionary with precomputed (key, value) tensors per attention layer and the
                cached length under ``'slen'``, used to speed up sequential decoding; it is updated in place.
            start_positions: position of the first token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
                Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
                into account for computing the loss.
            end_positions: position of the last token for the labeled span: ``torch.LongTensor`` of shape [batch_size].
                Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
                into account for computing the loss.
            cls_index: an optional ``torch.LongTensor`` of shape [batch_size] with the position of the
                classification token of each example, used for the answerability prediction.
            is_impossible: an optional tensor of shape [batch_size] indicating whether a question has no answer,
                used as the answerability target.
            p_mask: an optional mask of shape [batch_size, sequence_length] with 1.0 for tokens that cannot be
                part of the answer (e.g. question tokens and special tokens) and 0.0 otherwise.
            head_mask: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values
                between 0 and 1, used to nullify selected heads of the transformer: 1.0 keeps a head,
                0.0 masks it.

        Returns:
            If ``start_positions`` and ``end_positions`` are provided, outputs the ``total_loss``, which combines
            the cross-entropy losses for the start and end positions (plus the answerability loss when
            ``is_impossible`` is given).

            Otherwise, outputs the span-prediction tensors produced by the SQuAD head (top start/end
            log-probabilities and indices, and the answerability logit), followed by the hidden states and
            attention weights when the corresponding ``config`` flags are set.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            start_positions = torch.LongTensor([1, 0])
            end_positions = torch.LongTensor([2, 1])

            outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
            total_loss = outputs[0]
        """

        transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

        output = transformer_outputs[0]

        outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
                                  cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)

        outputs = outputs + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        return outputs