modeling_xlm.py 44.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import sys
from io import open

import itertools
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

33
34
from .modeling_utils import (PretrainedConfig, PreTrainedModel,
                             prune_linear_layer, SequenceSummary, SQuADHead)
35
36
37

logger = logging.getLogger(__name__)

38
# Shortcut name -> URL of the pretrained PyTorch weights. Every URL follows the same
# "<bucket>/<shortcut name>-pytorch_model.bin" pattern, so the map is derived from the names.
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut_name: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-pytorch_model.bin".format(shortcut_name)
    for shortcut_name in (
        'xlm-mlm-en-2048',
        'xlm-mlm-ende-1024',
        'xlm-mlm-enfr-1024',
        'xlm-mlm-enro-1024',
        'xlm-mlm-tlm-xnli15-1024',
        'xlm-mlm-xnli15-1024',
        'xlm-clm-enfr-1024',
        'xlm-clm-ende-1024',
    )
}
48
# Shortcut name -> URL of the JSON configuration of each pretrained checkpoint.
# All entries share the "<shortcut name>-config.json" suffix (the 'xlm-mlm-enfr-1024' entry
# previously pointed at "-configl.json", which does not match that naming scheme).
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
}


class XLMConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `XLMModel`.

    Args:
        vocab_size_or_config_json_file: Either the vocabulary size (int), stored as ``n_words``,
            or the path (str) to a JSON file whose key/value pairs are copied onto the instance.
            When a path is given, the remaining keyword defaults below are NOT applied.
        emb_dim: Dimensionality of the embeddings and of the transformer hidden states.
            `XLMModel` uses ``4 * emb_dim`` for the feed-forward hidden size.
        n_layers: Number of transformer layers.
        n_heads: Number of attention heads per layer (``emb_dim`` must be a multiple of it).
        dropout: Dropout probability applied to the embeddings and to the feed-forward outputs.
        attention_dropout: Dropout probability applied to the attention weights.
        gelu_activation: If True the feed-forward layers use GELU, otherwise ReLU.
        sinusoidal_embeddings: If True the position embeddings are filled with fixed
            sinusoidal values instead of being left randomly initialized.
        causal: If True a lower-triangular (left-to-right) attention mask is used.
        asm: Stored on the config; not read by the model code in this file
            (presumably an adaptive-softmax switch -- TODO confirm).
        n_langs: Number of languages. Language embeddings are only created when ``n_langs > 1``.
        max_position_embeddings: Number of rows of the position embedding table, i.e. the
            maximum sequence length the model can be used with.
        embed_init_std: Standard deviation of the normal initializer used for ``nn.Embedding``
            weights (``None`` disables this initialization).
        layer_norm_eps: The epsilon used by LayerNorm.
        init_std: Standard deviation of the normal initializer used for ``nn.Linear``
            weights (``None`` disables this initialization).
        bos_index: Index of the beginning-of-sentence token.
        eos_index: Index of the end-of-sentence token.
        pad_index: Index of the padding token (used as ``padding_idx`` of the word embeddings
            and to infer sequence lengths when they are not provided).
        unk_index: Index of the unknown token.
        mask_index: Index of the mask token.
        is_encoder: Must be True; `XLMModel` raises ``NotImplementedError`` otherwise.
        finetuning_task: Name of the fine-tuning task, stored as-is.
        num_labels: Number of labels, stored as-is.
        summary_type, summary_use_proj, summary_activation, summary_proj_to_labels,
        summary_first_dropout: Stored as-is; presumably consumed by the sequence-summary
            head (``SequenceSummary``) -- TODO confirm against ``modeling_utils``.
        start_n_top, end_n_top: Stored as-is; presumably consumed by the question-answering
            head (``SQuADHead``) -- TODO confirm against ``modeling_utils``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If the first argument is neither an int nor a str path.
    """
    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30145,
                 emb_dim=2048,
                 n_layers=12,
                 n_heads=16,
                 dropout=0.1,
                 attention_dropout=0.1,
                 gelu_activation=True,
                 sinusoidal_embeddings=False,
                 causal=False,
                 asm=False,
                 n_langs=1,
                 max_position_embeddings=512,
                 embed_init_std=2048 ** -0.5,
                 layer_norm_eps=1e-12,
                 init_std=0.02,
                 bos_index=0,
                 eos_index=1,
                 pad_index=2,
                 unk_index=3,
                 mask_index=5,
                 is_encoder=True,

                 finetuning_task=None,
                 num_labels=2,
                 summary_type='first',
                 summary_use_proj=True,
                 summary_activation=None,
                 summary_proj_to_labels=True,
                 summary_first_dropout=0.1,
                 start_n_top=5,
                 end_n_top=5,
                 **kwargs):
        """Constructs XLMConfig.
        """
        super(XLMConfig, self).__init__(**kwargs)

        # `unicode` only exists on Python 2; the version check short-circuits on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            # Path to a JSON config: every key of the file becomes an attribute.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_words = vocab_size_or_config_json_file
            self.emb_dim = emb_dim
            self.n_layers = n_layers
            self.n_heads = n_heads
            self.dropout = dropout
            self.attention_dropout = attention_dropout
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.causal = causal
            self.asm = asm
            self.n_langs = n_langs
            self.layer_norm_eps = layer_norm_eps
            self.bos_index = bos_index
            self.eos_index = eos_index
            self.pad_index = pad_index
            self.unk_index = unk_index
            self.mask_index = mask_index
            self.is_encoder = is_encoder
            self.max_position_embeddings = max_position_embeddings
            self.embed_init_std = embed_init_std
            self.init_std = init_std
            self.finetuning_task = finetuning_task
            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_proj_to_labels = summary_proj_to_labels
            self.summary_first_dropout = summary_first_dropout
            self.start_n_top = start_n_top
            self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    # The properties below expose the XLM-specific attribute names under the generic
    # names `vocab_size`, `hidden_size`, `num_attention_heads` and `num_hidden_layers`.
    @property
    def vocab_size(self):
        return self.n_words

    @property
    def hidden_size(self):
        return self.emb_dim

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215

def create_sinusoidal_embeddings(n_pos, dim, out):
    """Fill ``out`` in place with fixed sinusoidal position encodings.

    Entry ``(pos, j)`` is ``sin(pos / 10000 ** (2 * (j // 2) / dim))`` for even ``j`` and the
    ``cos`` of the same angle for odd ``j``.

    Args:
        n_pos: number of positions (rows of ``out``).
        dim: encoding size (columns of ``out``).
        out: tensor of shape ``(n_pos, dim)`` that receives the values, e.g. the weight of an
            ``nn.Embedding``. On return it is detached and has ``requires_grad = False``.

    Returns:
        None; ``out`` is modified in place.
    """
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        for pos in range(n_pos)
    ])
    # `out` is typically a leaf Parameter that requires grad, and autograd refuses in-place
    # writes into such a tensor. Do the fill with gradient tracking switched off.
    with torch.no_grad():
        out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    out.requires_grad = False


def gelu(x):
    """
    Gaussian Error Linear Unit, exact erf form: ``0.5 * x * (1 + erf(x / sqrt(2)))``.

    Paper: https://arxiv.org/abs/1606.08415
    The tanh-based approximation of the same function is not used here.
    """
    half_x = 0.5 * x
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)


thomwolf's avatar
thomwolf committed
222
def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate hidden states mask, and optionally an attention mask.

    Args:
        slen: sequence length (number of positions).
        lengths: tensor of size ``(bs,)`` with the length of each sequence.
        causal: if True, the attention mask is lower-triangular (position ``i`` may attend to
            positions ``j <= i``); otherwise the attention mask is the hidden states mask.
        padding_mask: optional ``(bs, slen)`` mask used as the hidden states mask instead of
            the one derived from ``lengths``.

    Returns:
        ``(mask, attn_mask)`` where ``mask`` has size ``(bs, slen)`` and ``attn_mask`` has size
        ``(bs, slen, slen)`` when ``causal`` is True, and is ``mask`` itself otherwise.
        Note that the causal attention mask is built from positions only.
    """
    bs = lengths.size(0)
    # Position indices are needed both for the length-based mask and for the causal mask,
    # so they are built up front (building them only in the `padding_mask is None` branch
    # made `causal=True` with an explicit padding mask fail with a NameError).
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # sanity check
    assert mask.size() == (bs, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional key/value caching.

    Each instance takes a unique ``layer_id`` from a class-level counter; that id is the key
    under which the layer stores its keys/values in the ``cache`` dict passed to ``forward``.
    """

    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        """
        Args:
            n_heads: number of attention heads.
            dim: model dimension; must be a multiple of ``n_heads``.
            config: object providing ``output_attentions`` and ``attention_dropout``.
        """
        super(MultiHeadAttention, self).__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.output_attentions = config.output_attentions
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        # query / key / value projections followed by the output projection
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

    def prune_heads(self, heads):
        """Remove the given head indices from the four projection layers (no-op if empty)."""
        head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        # One row per head; rows of pruned heads are zeroed, the rest mark features to keep.
        keep = torch.ones(self.n_heads, head_size)
        for head in heads:
            keep[head] = 0
        keep = keep.view(-1).contiguous().eq(1)
        index = torch.arange(len(keep))[keep].long()
        # Prune linear layers
        for name in ('q_lin', 'k_lin', 'v_lin'):
            setattr(self, name, prune_linear_layer(getattr(self, name), index))
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = head_size * self.n_heads

    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: queries, size ``(bs, qlen, dim)``.
            mask: size ``(bs, klen)`` or, when 3-dimensional, ``(bs, qlen, klen)``; positions
                equal to 0 are excluded from attention.
            kv: optional source states used for keys/values instead of ``input``.
            cache: optional dict holding ``'slen'`` and, per ``layer_id``, a ``(k, v)`` pair;
                it is updated in place with this layer's keys/values.
            head_mask: optional tensor multiplied into the attention weights.

        Returns:
            ``(output,)`` or ``(output, weights)`` when ``output_attentions`` is set.
        """
        bs, qlen, _ = input.size()
        if kv is not None:
            klen = kv.size(1)
        elif cache is None:
            klen = qlen
        else:
            klen = cache['slen'] + qlen
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        if mask.dim() == 3:
            mask_reshape = (bs, 1, qlen, klen)
        else:
            mask_reshape = (bs, 1, 1, klen)

        def split_heads(x):
            # (bs, len, dim) -> (bs, n_heads, len, dim_per_head)
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def merge_heads(x):
            # (bs, n_heads, len, dim_per_head) -> (bs, len, dim)
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = split_heads(self.q_lin(input))
        if kv is None:
            k = split_heads(self.k_lin(input))
            v = split_heads(self.v_lin(input))
        elif cache is None or self.layer_id not in cache:
            # Keys/values over the source are only projected when not already cached.
            k = split_heads(self.k_lin(kv))
            v = split_heads(self.v_lin(kv))

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Self-attention: prepend the cached keys/values along the length axis.
                    past_k, past_v = cache[self.layer_id]
                    k = torch.cat([past_k, k], dim=2)
                    v = torch.cat([past_v, v], dim=2)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))              # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)
        scores.masked_fill_(mask, -float('inf'))

        # softmax is computed in float32 and cast back to the dtype of the scores
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = merge_heads(torch.matmul(weights, v))          # (bs, qlen, dim)

        outputs = (self.out_lin(context),)
        if self.output_attentions:
            outputs = outputs + (weights,)
        return outputs
345
346
347
348


class TransformerFFN(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout."""

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        """
        Args:
            in_dim: size of the input features.
            dim_hidden: size of the intermediate layer.
            out_dim: size of the output features.
            config: object providing ``dropout`` and ``gelu_activation``.
        """
        super(TransformerFFN, self).__init__()
        self.dropout = config.dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)
        # GELU when requested by the config, ReLU otherwise
        if config.gelu_activation:
            self.act = gelu
        else:
            self.act = F.relu

    def forward(self, input):
        hidden = self.act(self.lin1(input))
        projected = self.lin2(hidden)
        return F.dropout(projected, p=self.dropout, training=self.training)


364
class XLMPreTrainedModel(PreTrainedModel):
    """ Abstract base class for the XLM models: it wires in the XLM configuration class and
        pretrained-weights archive map, and defines how submodules are initialized.
    """
    config_class = XLMConfig
    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of a single submodule (meant to be used with ``apply``).

        - ``nn.Embedding``: normal(0, ``config.embed_init_std``), skipped when that std is None.
        - ``nn.Linear``: normal(0, ``config.init_std``) weights and zero bias, skipped when
          that std is None.
        - ``nn.LayerNorm``: bias set to 0 and weight set to 1.
        """
        if (isinstance(module, nn.Embedding)
                and self.config is not None
                and self.config.embed_init_std is not None):
            nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
        if (isinstance(module, nn.Linear)
                and self.config is not None
                and self.config.init_std is not None):
            nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
            bias = getattr(module, 'bias', None)
            if bias is not None:
                nn.init.constant_(bias, 0.)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class XLMModel(XLMPreTrainedModel):
392
393
    """
    XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
394

395
    Paper: https://arxiv.org/abs/1901.07291
thomwolf's avatar
thomwolf committed
396

397
    Original code: https://github.com/facebookresearch/XLM
thomwolf's avatar
thomwolf committed
398

399
400
401
402
403
    Args:
        `config`: a XLMConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
            This can be used to compute head importance metrics. Default: False
thomwolf's avatar
thomwolf committed
404

405
    Example::
thomwolf's avatar
thomwolf committed
406

thomwolf's avatar
thomwolf committed
407
        config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
thomwolf's avatar
thomwolf committed
408
409
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

thomwolf's avatar
thomwolf committed
410
        model = modeling.XLMModel(config=config)
411
412
413
414
415
416
417
418
    """

    ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output', 
                  'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', 
                  'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                  'asm_cutoffs', 'asm_div_value']

    def __init__(self, config):
        """Build the embeddings and the stack of transformer layers described by ``config``.

        Raises:
            NotImplementedError: if ``config.is_encoder`` is False (decoder mode is not supported).
        """
        super(XLMModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        # encoder / decoder: only the encoder is available
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index

        # model parameters
        self.dim = config.emb_dim
        self.hidden_dim = self.dim * 4  # the feed-forward layers are 4x wider than the model
        self.n_heads = config.n_heads
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings: positions, (optional) languages, words, then a LayerNorm over their sum
        layer_norm_eps = config.layer_norm_eps
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        if config.n_langs > 1:
            # language embeddings only exist for multilingual configurations
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=layer_norm_eps)

        # transformer layers: layer i is (attentions[i], layer_norm1[i], ffns[i], layer_norm2[i])
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()

        for _ in range(self.n_layers):
            attention = MultiHeadAttention(self.n_heads, self.dim, config=config)
            self.attentions.append(attention)
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=layer_norm_eps))
            feed_forward = TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config)
            self.ffns.append(feed_forward)
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=layer_norm_eps))

        # NOTE(review): init_weights re-draws every nn.Embedding weight when
        # config.embed_init_std is set, which includes a sinusoidal position table filled
        # above -- confirm that this is intended.
        self.apply(self.init_weights)
479

thomwolf's avatar
thomwolf committed
480
481
482
    def _resize_token_embeddings(self, new_num_tokens):
        """Replace the word embeddings with a version resized to ``new_num_tokens`` entries.

        The resizing itself is delegated to ``self._get_resized_embeddings``.
        """
        resized = self._get_resized_embeddings(self.embeddings, new_num_tokens)
        self.embeddings = resized

thomwolf's avatar
thomwolf committed
483
484
485
486
487
488
489
490
    def _prune_heads(self, heads_to_prune):
        """ Prune attention heads, layer by layer.
            heads_to_prune: dict mapping a layer index to the list of head indices to remove
            from that layer's attention module.
        """
        for layer_idx, head_ids in heads_to_prune.items():
            attention = self.attentions[layer_idx]
            attention.prune_heads(head_ids)

thomwolf's avatar
thomwolf committed
491
492
    def forward(self, input_ids, lengths=None, positions=None, langs=None,
                token_type_ids=None, attention_mask=None, cache=None, head_mask=None):  # src_enc=None, src_len=None, 
493
        """
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Parameters:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
            `lengths`: ``torch.LongTensor`` of size ``bs``, containing the length of each sentence
            `positions`: ``torch.LongTensor`` of size ``(bs, slen)``, containing word positions
            `langs`: ``torch.LongTensor`` of size ``(bs, slen)``, containing language IDs
            `token_type_ids`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with the token
                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
                a `sentence B` token (see XLM paper for more details).
            `attention_mask`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length] with indices
                selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
                input sequence length in the current batch. It's the mask that we typically use for attention when
                a batch has varying length sentences.
            `cache`: TODO
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
                It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.


        Returns:
            A ``tuple(encoded_layers, pooled_output)``, with

            ``encoded_layers``: controlled by ``output_all_encoded_layers`` argument:

                - ``output_all_encoded_layers=True``: outputs a list of the full sequences of encoded-hidden-states at the end \
                of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each \
                encoded-hidden-state is a ``torch.FloatTensor`` of size [batch_size, sequence_length, hidden_size],

                - ``output_all_encoded_layers=False``: outputs only the full sequence of hidden-states corresponding \
                to the last attention block of shape [batch_size, sequence_length, hidden_size],

            ``pooled_output``: a ``torch.FloatTensor`` of size [batch_size, hidden_size] which is the output of a
            classifier pre-trained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see XLM's paper).

        Example::

            # Already been converted into WordPiece token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
            token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
            # or
            all_encoder_layers, pooled_output = model.forward(input_ids, token_type_ids, input_mask)
541
        """
thomwolf's avatar
thomwolf committed
542
        if lengths is None:
thomwolf's avatar
thomwolf committed
543
            lengths = (input_ids != self.pad_index).sum(dim=1).long()
thomwolf's avatar
xlm  
thomwolf committed
544
        # mask = input_ids != self.pad_index
545
546

        # check inputs
thomwolf's avatar
xlm  
thomwolf committed
547
        bs, slen = input_ids.size()
548
549
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
thomwolf's avatar
xlm  
thomwolf committed
550
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
thomwolf's avatar
thomwolf committed
551
552
553
554
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
        #     assert src_enc.size(0) == bs
555
556

        # generate masks
thomwolf's avatar
thomwolf committed
557
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
thomwolf's avatar
thomwolf committed
558
559
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
560
561
562

        # positions
        if positions is None:
thomwolf's avatar
thomwolf committed
563
            positions = input_ids.new((slen,)).long()
564
565
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
thomwolf's avatar
thomwolf committed
566
567
            assert positions.size() == (bs, slen)  # (slen, bs)
            # positions = positions.transpose(0, 1)
568
569

        # langs
thomwolf's avatar
thomwolf committed
570
571
572
        assert langs is None or token_type_ids is None, "You can only use one among langs and token_type_ids"
        if token_type_ids is not None:
            langs = token_type_ids
573
        if langs is not None:
thomwolf's avatar
thomwolf committed
574
575
            assert langs.size() == (bs, slen)  # (slen, bs)
            # langs = langs.transpose(0, 1)
576

thomwolf's avatar
thomwolf committed
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layers

592
593
594
        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
thomwolf's avatar
xlm  
thomwolf committed
595
            input_ids = input_ids[:, -_slen:]
596
597
598
599
600
601
602
            positions = positions[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
thomwolf's avatar
xlm  
thomwolf committed
603
        tensor = self.embeddings(input_ids)
604
605
606
607
608
609
610
611
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        if langs is not None:
            tensor = tensor + self.lang_embeddings(langs)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
thomwolf's avatar
thomwolf committed
612
613
        hidden_states = ()
        attentions = ()
614
        for i in range(self.n_layers):
thomwolf's avatar
thomwolf committed
615
            if self.output_hidden_states:
thomwolf's avatar
thomwolf committed
616
                hidden_states = hidden_states + (tensor,)
617
618

            # self attention
thomwolf's avatar
thomwolf committed
619
620
621
            attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
            attn = attn_outputs[0]
            if self.output_attentions:
thomwolf's avatar
thomwolf committed
622
                attentions = attentions + (attn_outputs[1],)
623
624
625
626
627
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
thomwolf's avatar
thomwolf committed
628
629
630
631
632
            # if self.is_decoder and src_enc is not None:
            #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            #     attn = F.dropout(attn, p=self.dropout, training=self.training)
            #     tensor = tensor + attn
            #     tensor = self.layer_norm15[i](tensor)
633
634
635
636
637
638

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

thomwolf's avatar
thomwolf committed
639
640
        # Add last hidden state
        if self.output_hidden_states:
thomwolf's avatar
thomwolf committed
641
            hidden_states = hidden_states + (tensor,)
thomwolf's avatar
thomwolf committed
642

643
644
645
646
647
        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        # move back sequence length to dimension 0
thomwolf's avatar
thomwolf committed
648
        # tensor = tensor.transpose(0, 1)
649

thomwolf's avatar
thomwolf committed
650
        outputs = (tensor,)
651
        if self.output_hidden_states:
thomwolf's avatar
thomwolf committed
652
            outputs = outputs + (hidden_states,)
thomwolf's avatar
thomwolf committed
653
        if self.output_attentions:
thomwolf's avatar
thomwolf committed
654
            outputs = outputs + (attentions,)
thomwolf's avatar
thomwolf committed
655
        return outputs  # outputs, (hidden_states), (attentions)
656
657
658
659
660
661


class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).

    Maps hidden states to vocabulary scores, either through a plain linear
    projection (``config.asm is False``) or through
    ``nn.AdaptiveLogSoftmaxWithLoss`` (any other value of ``config.asm``).

    Args:
        config: configuration object; the attributes read here are ``asm``,
            ``n_words``, ``pad_index``, ``emb_dim`` and, when the adaptive
            softmax is used, ``asm_cutoffs`` and ``asm_div_value``.
    """
    def __init__(self, config):
        super(XLMPredLayer, self).__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """ Compute the scores, and optionally the loss.

        Args:
            x: hidden states whose last dimension is ``emb_dim``; any leading
                dimensions are flattened into a single one.
            y: optional target indices with one entry per position of ``x``
                (e.g. ``[batch_size, sequence_length]`` or already flattened).

        Returns:
            ``(scores,)`` when ``y`` is None, ``(loss, scores)`` otherwise.
            ``scores`` are flattened to ``[-1, n_words]``: raw logits for the
            linear projection, log-probabilities for the adaptive softmax.
        """
        if self.asm is False:
            scores = self.proj(x).view(-1, self.n_words)
            outputs = (scores,)
            if y is not None:
                # scores are flattened above, so the targets must be flattened
                # too; otherwise [batch, seq] labels raise a shape mismatch.
                loss = F.cross_entropy(scores, y.reshape(-1), reduction='mean')
                outputs = (loss,) + outputs
        else:
            # AdaptiveLogSoftmaxWithLoss does not accept inputs with more than
            # two dimensions, so collapse leading dimensions when present.
            if x.dim() > 2:
                x = x.reshape(-1, x.size(-1))
            scores = self.proj.log_prob(x)
            outputs = (scores,)
            if y is not None:
                if y.dim() > 1:
                    y = y.reshape(-1)
                _, loss = self.proj(x, y)
                outputs = (loss,) + outputs

        return outputs
698

thomwolf's avatar
thomwolf committed
699
700

class XLMWithLMHeadModel(XLMPreTrainedModel):
    """ XLM model with a language modeling head, from: "Cross-lingual Language Model Pretraining"
    by Guillaume Lample, Alexis Conneau

    Paper: https://arxiv.org/abs/1901.07291

    Original code: https://github.com/facebookresearch/XLM

    The model is an ``XLMModel`` transformer followed by an ``XLMPredLayer`` whose projection
    weight is tied to (or, for TorchScript, cloned from) the transformer's input embeddings.

    Args:
        `config`: a XLMConfig class instance with the configuration to build a new model

    Example::

        model = XLMWithLMHeadModel(config)
        scores = model(input_ids)[0]
    """
    def __init__(self, config):
        super(XLMWithLMHeadModel, self).__init__(config)
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Point the prediction layer's projection weight at the input embeddings.

        With ``config.torchscript`` set, an independent copy is used instead of the
        shared parameter.
        """
        embedding_weight = self.transformer.embeddings.weight
        if self.config.torchscript:
            embedding_weight = nn.Parameter(embedding_weight.clone())
        self.pred_layer.proj.weight = embedding_weight

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Run the transformer, then the prediction layer on its first output.

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the word token indices in the vocabulary
            `lengths`, `positions`, `langs`, `token_type_ids`, `attention_mask`, `cache`:
                optional inputs forwarded unchanged, by keyword, to ``self.transformer``
            `labels`: optional targets handed to ``self.pred_layer``; when given, a loss is
                prepended to the returned tuple
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads],
                forwarded to ``self.transformer`` to mask attention heads.
                NOTE(review): the transformer code comments that 1.0 means the head is kept -- confirm.

        Returns:
            A tuple made of the outputs of ``self.pred_layer`` (``(scores,)`` without labels,
            ``(loss, scores)`` with labels) followed by every transformer output after the first
            one (hidden states / attentions, when the transformer returns them).

        Example::

            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            scores = model(input_ids)[0]
            # or
            loss, scores = model(input_ids, labels=labels)[:2]
        """
        encoder_outputs = self.transformer(input_ids,
                                           lengths=lengths,
                                           positions=positions,
                                           token_type_ids=token_type_ids,
                                           langs=langs,
                                           attention_mask=attention_mask,
                                           cache=cache,
                                           head_mask=head_mask)

        last_hidden = encoder_outputs[0]
        head_outputs = self.pred_layer(last_hidden, labels)

        # Keep attention/hidden states if they are here
        return head_outputs + encoder_outputs[1:]
794
795
796
797
798


class XLMForSequenceClassification(XLMPreTrainedModel):
    """XLM model with a sequence classification/regression head.

    The model is an ``XLMModel`` transformer whose first output is reduced to logits by a
    ``SequenceSummary(config)`` module.

    Args:
        `config`: a XLMConfig class instance with the configuration to build a new model.
            ``config.num_labels`` selects the loss: regression (MSE) when it is 1,
            classification (cross-entropy) otherwise.

    Example::

        model = XLMForSequenceClassification(config)
        logits = model(input_ids)[0]
    """
    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, labels=None, head_mask=None):
        """
        Run the transformer, summarize its first output into logits and optionally compute a loss.

        Args:
            input_ids: token indices, forwarded positionally to ``self.transformer``
            lengths, positions, langs, token_type_ids, attention_mask, cache, head_mask:
                optional inputs forwarded unchanged, by keyword, to ``self.transformer``
            labels: optional targets. They are flattened with ``view(-1)`` and compared to the
                logits with ``MSELoss`` when ``self.num_labels == 1`` and with
                ``CrossEntropyLoss`` otherwise.

        Returns:
            A tuple ``(logits,)`` followed by every transformer output after the first one
            (hidden states / attentions, when the transformer returns them). When ``labels``
            is given, the loss is prepended: ``(loss, logits, ...)``.

        Example::

            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            logits = model(input_ids)[0]
            # or
            loss, logits = model(input_ids, labels=labels)[:2]
        """
        encoder_outputs = self.transformer(input_ids,
                                           lengths=lengths,
                                           positions=positions,
                                           token_type_ids=token_type_ids,
                                           langs=langs,
                                           attention_mask=attention_mask,
                                           cache=cache,
                                           head_mask=head_mask)

        logits = self.sequence_summary(encoder_outputs[0])

        # Keep attention/hidden states if they are here
        outputs = (logits,) + encoder_outputs[1:]

        if labels is None:
            return outputs

        flat_labels = labels.view(-1)
        if self.num_labels == 1:
            # A single label means we are doing regression
            loss = MSELoss()(logits.view(-1), flat_labels)
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), flat_labels)

        return (loss,) + outputs
880
881
882


class XLMForQuestionAnswering(XLMPreTrainedModel):
    """
    XLM model for Question Answering (span extraction).

    This module is composed of the XLM model with a ``SQuADHead`` on top of
    the sequence output.

    Args:
        `config`: a XLMConfig class instance with the configuration to build a new model

    Example::

        model = XLMForQuestionAnswering(config)
        outputs = model(input_ids)
    """
    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = SQuADHead(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                attention_mask=None, cache=None, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            input_ids: token indices, forwarded positionally to ``self.transformer``
            lengths, positions, langs, token_type_ids, attention_mask, cache, head_mask:
                optional inputs forwarded unchanged, by keyword, to ``self.transformer``
            start_positions, end_positions, cls_index, is_impossible, p_mask:
                optional inputs forwarded unchanged, by keyword, to ``self.qa_outputs``
                (presumably the labeled span boundaries and auxiliary SQuAD inputs --
                see ``SQuADHead`` for their exact meaning)

        Returns:
            A tuple made of the outputs of ``self.qa_outputs`` followed by every transformer
            output after the first one (hidden states / attentions, when the transformer
            returns them). The content of the head outputs is defined by ``SQuADHead``.

        Example::

            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            outputs = model(input_ids)
            # or
            outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        """
        encoder_outputs = self.transformer(input_ids,
                                           lengths=lengths,
                                           positions=positions,
                                           token_type_ids=token_type_ids,
                                           langs=langs,
                                           attention_mask=attention_mask,
                                           cache=cache,
                                           head_mask=head_mask)

        sequence_output = encoder_outputs[0]
        head_outputs = self.qa_outputs(sequence_output,
                                       start_positions=start_positions,
                                       end_positions=end_positions,
                                       cls_index=cls_index,
                                       is_impossible=is_impossible,
                                       p_mask=p_mask)

        # Keep attention/hidden states if they are here
        return head_outputs + encoder_outputs[1:]