modeling_xlm.py 46.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import sys
from io import open

import itertools
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

thomwolf's avatar
thomwolf committed
33
from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings,
34
                             prune_linear_layer, SequenceSummary, SQuADHead)
35
36
37

logger = logging.getLogger(__name__)

# Shortcut name -> URL of the pretrained PyTorch weights for each released XLM checkpoint.
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
}

# Shortcut name -> URL of the JSON configuration for each released XLM checkpoint.
# Must have exactly the same keys as XLM_PRETRAINED_MODEL_ARCHIVE_MAP.
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
}


class XLMConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `XLMModel`.

    Args:
        vocab_size_or_config_json_file: Either the vocabulary size (int) of `inputs_ids`
            in `XLMModel`, or the path (str) to a JSON configuration file whose top-level
            keys are copied verbatim onto the config instance.
        emb_dim: Size of the embeddings and hidden states (also exposed as `hidden_size`).
        n_layers: Number of hidden layers in the Transformer encoder
            (also exposed as `num_hidden_layers`).
        n_heads: Number of attention heads in each attention layer
            (also exposed as `num_attention_heads`).
        dropout: Dropout probability applied to the output of the feed-forward layers.
        attention_dropout: Dropout probability applied to the attention weights.
        gelu_activation: If True, the feed-forward layers use GELU; otherwise ReLU.
        sinusoidal_embeddings: Whether to use fixed sinusoidal position embeddings
            rather than learned ones (consumed by the model, outside this class).
        causal: If True, attention is restricted to previous positions (triangular mask),
            as used by causal language modeling (CLM) checkpoints.
        asm: Presumably enables an adaptive-softmax output layer — confirm against the
            model head implementation.
        n_langs: Number of languages handled by the model.
        use_lang_emb: Whether language embeddings are used (meaningful when n_langs > 1).
        max_position_embeddings: The maximum sequence length this model can be used with.
        embed_init_std: Standard deviation of the normal initializer for embedding matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
        init_std: Standard deviation of the normal initializer for linear-layer weights.
        bos_index: Vocabulary index of the beginning-of-sentence token.
        eos_index: Vocabulary index of the end-of-sentence token.
        pad_index: Vocabulary index of the padding token.
        unk_index: Vocabulary index of the unknown token.
        mask_index: Vocabulary index of the mask token.
        is_encoder: Whether the model is used as an encoder.
        finetuning_task: Name of the fine-tuning task, if any.
        num_labels: Number of labels for classification heads.
        summary_type: How a sequence is summarized into a single vector by
            sequence-summary heads (e.g. 'first').
        summary_use_proj: Whether the sequence summary applies a projection.
        summary_activation: Optional activation applied to the sequence summary.
        summary_proj_to_labels: Whether the summary projection outputs `num_labels` values.
        summary_first_dropout: Dropout probability applied before the summary projection.
        start_n_top: Number of top start positions kept by the question-answering head.
        end_n_top: Number of top end positions kept by the question-answering head.

    Raises:
        ValueError: if the first argument is neither an int nor a str path.
    """
    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30145,
                 emb_dim=2048,
                 n_layers=12,
                 n_heads=16,
                 dropout=0.1,
                 attention_dropout=0.1,
                 gelu_activation=True,
                 sinusoidal_embeddings=False,
                 causal=False,
                 asm=False,
                 n_langs=1,
                 use_lang_emb=True,
                 max_position_embeddings=512,
                 embed_init_std=2048 ** -0.5,
                 layer_norm_eps=1e-12,
                 init_std=0.02,
                 bos_index=0,
                 eos_index=1,
                 pad_index=2,
                 unk_index=3,
                 mask_index=5,
                 is_encoder=True,
                 finetuning_task=None,
                 num_labels=2,
                 summary_type='first',
                 summary_use_proj=True,
                 summary_activation=None,
                 summary_proj_to_labels=True,
                 summary_first_dropout=0.1,
                 start_n_top=5,
                 end_n_top=5,
                 **kwargs):
        """Constructs XLMConfig.
        """
        super(XLMConfig, self).__init__(**kwargs)

        # `unicode` only exists on Python 2; the version check short-circuits before the
        # name is evaluated on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            # Attributes come straight from the JSON file; constructor defaults are not applied.
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_words = vocab_size_or_config_json_file
            self.emb_dim = emb_dim
            self.n_layers = n_layers
            self.n_heads = n_heads
            self.dropout = dropout
            self.attention_dropout = attention_dropout
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.causal = causal
            self.asm = asm
            self.n_langs = n_langs
            self.use_lang_emb = use_lang_emb
            self.layer_norm_eps = layer_norm_eps
            self.bos_index = bos_index
            self.eos_index = eos_index
            self.pad_index = pad_index
            self.unk_index = unk_index
            self.mask_index = mask_index
            self.is_encoder = is_encoder
            self.max_position_embeddings = max_position_embeddings
            self.embed_init_std = embed_init_std
            self.init_std = init_std
            self.finetuning_task = finetuning_task
            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_proj_to_labels = summary_proj_to_labels
            self.summary_first_dropout = summary_first_dropout
            self.start_n_top = start_n_top
            self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             " or the path to a pretrained model config file (str)")

    # Generic names used by shared library code, aliased onto XLM's native attribute names.
    @property
    def vocab_size(self):
        return self.n_words

    @vocab_size.setter
    def vocab_size(self, value):
        self.n_words = value

    @property
    def hidden_size(self):
        return self.emb_dim

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225

def create_sinusoidal_embeddings(n_pos, dim, out):
    """Fill `out` in place with fixed sinusoidal position encodings and freeze it.

    For position `pos` and channel `j`, the angle is pos / 10000 ** (2 * (j // 2) / dim).
    Even channels receive sin(angle) and odd channels receive cos(angle).

    Args:
        n_pos: number of positions (rows of `out`).
        dim: encoding dimensionality (columns of `out`).
        out: tensor of shape (n_pos, dim) written in place; it is detached and
            marked as not requiring gradients afterwards.
    """
    rows = []
    for pos in range(n_pos):
        # Channels 2k and 2k + 1 share the same frequency.
        rows.append([pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)])
    angles = np.array(rows)

    even_channels = angles[:, 0::2]
    odd_channels = angles[:, 1::2]
    out[:, 0::2] = torch.FloatTensor(np.sin(even_channels))
    out[:, 1::2] = torch.FloatTensor(np.cos(odd_channels))

    # These encodings are constants, never trained.
    out.detach_()
    out.requires_grad = False


def gelu(x):
    """
    GELU activation, exact erf-based form: 0.5 * x * (1 + erf(x / sqrt(2))).

    The cheaper tanh approximation is deliberately not used.

    https://arxiv.org/abs/1606.08415
    https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
    https://github.com/huggingface/pytorch-transformers/blob/master/modeling.py
    """
    half_x = 0.5 * x
    gaussian_cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * gaussian_cdf_term


thomwolf's avatar
thomwolf committed
232
def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate hidden states mask, and optionally an attention mask.

    Args:
        slen: sequence length of the batch.
        lengths: LongTensor of shape (bs,) with the length of each sequence. It is used to
            build the hidden-states mask when `padding_mask` is None. It always provides
            the batch size and the device.
        causal: if True, the attention mask is lower-triangular of shape (bs, slen, slen):
            position i may attend to positions j <= i.
        padding_mask: optional precomputed (bs, slen) mask, used as-is instead of `lengths`.

    Returns:
        (mask, attn_mask): `mask` has shape (bs, slen). `attn_mask` is the causal
        (bs, slen, slen) mask when `causal` is True, otherwise the same object as `mask`.
    """
    bs = lengths.size(0)
    # Position indices are needed both for the length-based mask and for the causal mask,
    # so build them unconditionally. Previously they were only built when padding_mask was
    # None, which broke causal=True combined with an explicit padding_mask.
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # sanity check
    assert mask.size() == (bs, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):
    """Multi-head attention: self-attention, or attention over a source sequence `kv`.

    Each instance takes a unique `layer_id` from a class-level counter. That id is the key
    under which the instance stores its keys/values in the optional `cache` dict passed to
    `forward`.
    """

    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        super(MultiHeadAttention, self).__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.output_attentions = config.output_attentions
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)
        # Original (pre-pruning) indices of the heads removed so far.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads (indices in the original, unpruned numbering).

        Heads that were already pruned are ignored. If nothing new remains to prune,
        the linear layers are left untouched.
        """
        attention_head_size = self.dim // self.n_heads
        heads = set(heads) - self.pruned_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        for head in heads:
            # Shift the original index down by the number of earlier heads already removed.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: (bs, qlen, dim) query-side hidden states.
            mask: (bs, klen), or (bs, qlen, klen) for causal attention. Zero entries are masked.
            kv: optional (bs, klen, dim) source states for cross-attention.
            cache: optional dict updated in place. Keys/values of this layer are stored under
                `self.layer_id`. For self-attention with a cache, `cache['slen']` must hold
                the number of positions already cached.
            head_mask: optional factor multiplied into the attention weights; presumably
                broadcastable to (bs, n_heads, qlen, klen) — confirm with callers.

        Returns:
            A tuple whose first element is the (bs, qlen, dim) output. The
            (bs, n_heads, qlen, klen) attention weights are appended when
            `output_attentions` is set.
        """
        bs, qlen, _ = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))                                          # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))                                          # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Incremental self-attention: append the new positions to the cached ones.
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                else:
                    # Cross-attention: source projections were computed on a previous call.
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)                                       # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)               # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))                              # (bs, n_heads, qlen, klen)

        # Softmax in float32 for stability, then cast back to the scores' dtype.
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        outputs = (self.out_lin(context),)
        if self.output_attentions:
            outputs = outputs + (weights,)
        return outputs
359
360
361
362


class TransformerFFN(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> activation -> Linear -> dropout.

    The activation is GELU when `config.gelu_activation` is truthy, otherwise ReLU.
    Dropout uses `config.dropout` and is only active in training mode.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super(TransformerFFN, self).__init__()
        self.dropout = config.dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)
        if config.gelu_activation:
            self.act = gelu
        else:
            self.act = F.relu

    def forward(self, input):
        hidden = self.act(self.lin1(input))
        projected = self.lin2(hidden)
        return F.dropout(projected, p=self.dropout, training=self.training)


378
class XLMPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = XLMConfig
    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    # No TensorFlow checkpoint conversion is provided for XLM.
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights of a single submodule.

        Embeddings are drawn from N(0, embed_init_std). Linear weights are drawn from
        N(0, init_std), with their biases zeroed. Each of these applies only when the
        config and the corresponding std are set. LayerNorm is always reset to weight=1,
        bias=0.
        """
        if isinstance(module, nn.Embedding):
            cfg = self.config
            if cfg is not None and cfg.embed_init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=cfg.embed_init_std)
        if isinstance(module, nn.Linear):
            cfg = self.config
            if cfg is not None and cfg.init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=cfg.init_std)
                has_bias = hasattr(module, 'bias') and module.bias is not None
                if has_bias:
                    nn.init.constant_(module.bias, 0.)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


thomwolf's avatar
thomwolf committed
405
406
407
XLM_START_DOCSTRING = r"""    The XLM model was proposed in
    `Cross-lingual Language Model Pretraining`_
    by Guillaume Lample*, Alexis Conneau*. It's a transformer pre-trained using one of the following objectives:

        - a causal language modeling (CLM) objective (next token prediction),
        - a masked language modeling (MLM) objective (Bert-like), or
        - a Translation Language Modeling (TLM) objective (extension of Bert's MLM to multiple language inputs)

    Original code can be found `here`_.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Cross-lingual Language Model Pretraining`:
        https://arxiv.org/abs/1901.07291

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    .. _`here`:
        https://github.com/facebookresearch/XLM

    Parameters:
        config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
432

thomwolf's avatar
thomwolf committed
433
434
435
436
XLM_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.

            XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens to be used to indicate the language of each token in the input.
            Indices are language ids which can be obtained from the language names by using two conversion mappings
            provided in the configuration of the model (only provided for multilingual models).
            More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
            the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **lengths**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Length of each sentence that can be used to avoid performing attention on padding token indices.
            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
            Indices selected in ``[0, ..., input_ids.size(-1)]``.
        **cache**:
            dictionary with ``torch.FloatTensor`` that contains pre-computed
            hidden-states (key and values in the attention blocks) as computed by the model
            (see `cache` output below). Can be used to speed up sequential decoding.
            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

@add_start_docstrings("The bare XLM Model transformer outputing raw hidden-states without any specific head on top.",
                      XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMModel(XLMPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMModel.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output',
                  'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
                  'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                  'asm_cutoffs', 'asm_div_value']

    def __init__(self, config):
        super(XLMModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        # encoder / decoder: only the encoder path is implemented in this port
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.use_lang_emb = config.use_lang_emb
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index

        # model parameters
        self.dim = config.emb_dim       # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
        self.n_heads = config.n_heads   # 8 by default
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
        # The language table only exists for multilingual configs that opt into it;
        # forward() checks the same two conditions before using it.
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers (post-norm: each LayerNorm is applied after its residual add)
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        # Re-apply the head pruning listed in config.pruned_heads. The dict is reset before
        # calling prune_heads(); layers whose head count already differs from config.n_heads
        # are skipped.
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
        return self.embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
        """ Encode ``input_ids`` of shape ``(batch_size, sequence_length)``.

        Returns ``(last_hidden_state, (hidden_states), (attentions))`` as described in the
        class docstring. When ``cache`` is given, only the positions not yet counted in
        ``cache['slen']`` are computed, and ``cache['slen']`` is advanced in place.
        """
        if lengths is None:
            # Count the non-padding tokens of each row.
            lengths = (input_ids != self.pad_index).sum(dim=1).long()

        # check inputs
        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)

        # position_ids: default to 0..slen-1, broadcast over the batch
        if position_ids is None:
            position_ids = torch.arange(slen, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        else:
            assert position_ids.size() == (bs, slen)

        # langs
        if langs is not None:
            assert langs.size() == (bs, slen)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and is reshaped to [num_hidden_layers x 1 x num_heads x 1 x 1] so layer i gets head_mask[i]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layers

        # do not recompute cached elements: keep only the last `_slen` (new) positions
        # of every per-position input so they all match the trimmed input_ids.
        if cache is not None:
            _slen = slen - cache['slen']
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(input_ids)
        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
        # Must mirror the condition under which __init__ creates lang_embeddings; checking
        # use_lang_emb alone raised AttributeError for single-language configs.
        if langs is not None and self.use_lang_emb and self.n_langs > 1:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            # Token types are looked up in the word-embedding table (no separate segment table).
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            if self.output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
            attn = attn_outputs[0]
            if self.output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            # Zero out padded positions again after each layer.
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if self.output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        outputs = (tensor,)
        if self.output_hidden_states:
            outputs = outputs + (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (attentions,)
        return outputs  # outputs, (hidden_states), (attentions)
702
703
704
705
706
707


class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).

    ``forward(x, y=None)`` returns ``(scores,)``, or ``(loss, scores)`` when targets ``y``
    are given. All leading dimensions of ``x`` are flattened, so ``scores`` has shape
    ``(num_rows, n_words)``.
    """
    def __init__(self, config):
        super(XLMPredLayer, self).__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """ Compute the scores, and the loss when targets are given.

        Args:
            x: hidden states whose last dimension is ``emb_dim``; all leading
               dimensions are flattened into rows.
            y: optional target token ids, one per row of the flattened ``x``
               (any shape with that many elements, e.g. ``(batch_size, sequence_length)``).

        Returns:
            ``(scores,)`` or ``(loss, scores)``. ``scores`` has shape ``(num_rows, n_words)``:
            raw logits for the linear head, log-probabilities for adaptive softmax.
        """
        outputs = ()
        if self.asm is False:
            scores = self.proj(x).view(-1, self.n_words)
            outputs = (scores,) + outputs
            if y is not None:
                # Targets must be flattened to match the flattened scores.
                # 'mean' replaces the deprecated 'elementwise_mean' spelling (same reduction).
                loss = F.cross_entropy(scores, y.reshape(-1), reduction='mean')
                outputs = (loss,) + outputs
        else:
            # AdaptiveLogSoftmaxWithLoss expects (N, in_features) input and (N,) targets.
            x = x.reshape(-1, x.size(-1))
            scores = self.proj.log_prob(x)
            outputs = (scores,) + outputs
            if y is not None:
                _, loss = self.proj(x, y.reshape(-1))
                outputs = (loss,) + outputs

        return outputs
744

thomwolf's avatar
thomwolf committed
745

thomwolf's avatar
thomwolf committed
746
747
748
@add_start_docstrings("""The XLM Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMWithLMHeadModel(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Labels are **not** shifted inside the model: the label at position ``i`` is the target
            for the prediction scores at position ``i``.
            Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``.
            With the linear head (``config.asm=False``), labels set to ``-100`` are ignored
            (the default ``ignore_index`` of ``F.cross_entropy``); the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size * sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax;
            log-probabilities when ``config.asm=True``), flattened over batch and sequence positions.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMWithLMHeadModel.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]  # Without labels, the prediction scores are the first element of the output tuple

    """
    def __init__(self, config):
        super(XLMWithLMHeadModel, self).__init__(config)
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        self.init_weights()
        # Tie after init_weights so the re-initialisation does not break the sharing.
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        """
        # NOTE(review): with config.asm=True, pred_layer.proj is an nn.AdaptiveLogSoftmaxWithLoss
        # rather than an nn.Linear; confirm that tying the input embeddings to it is meaningful.
        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
        """ Run the encoder and the LM head; see the class docstring for the returned tuple. """
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]
        # (loss,) scores  -- the loss is only present when labels are given
        outputs = self.pred_layer(output, labels)
        outputs = outputs + transformer_outputs[1:]  # Append hidden states / attentions if they were requested

        return outputs
809
810


thomwolf's avatar
thomwolf committed
811
812
813
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForSequenceClassification(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss);
            the labels are then cast to the dtype of the logits, so a float tensor of targets may be passed.
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMForSequenceClassification.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
        """ Encode, summarise the sequence into logits and optionally compute the loss.

        Returns ``((loss), logits, (hidden_states), (attentions))``.
        """
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]
        logits = self.sequence_summary(output)

        outputs = (logits,) + transformer_outputs[1:]  # Append hidden states / attentions if they were requested

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression; MSELoss needs floating-point targets, so integer
                #  labels (the documented LongTensor) are cast to the logits' dtype.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs
881
882


thomwolf's avatar
thomwolf committed
883
884
885
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForQuestionAnswering(XLMPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels whether a question has an answer or no answer (SQuAD 2.0)
        **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
        **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMForQuestionAnswering.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]  # The loss is the first element when span labels are given

    """
    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = SQuADHead(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
                is_impossible=None, cls_index=None, p_mask=None):
        """ Encode and apply the SQuAD head; all span/label arguments are forwarded to it unchanged. """
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]

        # NOTE(review): the elements returned here are produced by SQuADHead; the Outputs
        # section of the class docstring should be checked against that head's documentation.
        outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
                                  cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)

        outputs = outputs + transformer_outputs[1:]  # Append hidden states / attentions if they were requested

        return outputs