"test/gmock-1.7.0/gtest/codegear/gtest_all.cc" did not exist on "db82302ed06af1c100aa7b626845b4f6cbf955b4"
modeling_xlm.py 39.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLM model.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math

import itertools
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

30
31
32
from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, SQuADHead
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings
33
34
35

logger = logging.getLogger(__name__)

# Shortcut name -> URL of the pretrained PyTorch weights for each published XLM checkpoint.
XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-pytorch_model.bin",
    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-pytorch_model.bin",
    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-pytorch_model.bin",
    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
}


def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill ``out`` in place with fixed sinusoidal position encodings and freeze it.

    Entry ``(pos, j)`` is ``sin(pos / 10000 ** (2 * (j // 2) / dim))`` for even ``j`` and
    ``cos(...)`` of the same angle for odd ``j``.

    Args:
        n_pos: number of positions (rows of ``out``) to fill.
        dim: encoding dimension (columns of ``out``).
        out: tensor with ``n_pos`` rows and ``dim`` columns, e.g. an ``nn.Embedding`` weight.
            It is written in place, then detached and marked ``requires_grad = False``.
    """
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        for pos in range(n_pos)
    ])
    # Disable gradient tracking *before* writing. An in-place write into a leaf tensor
    # that requires grad (such as an nn.Parameter) raises a RuntimeError in autograd.
    out.requires_grad = False
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()


def gelu(x):
    """
    Gaussian Error Linear Unit in its exact, erf-based form: ``x * Phi(x)``, where
    ``Phi`` is the standard normal CDF.

    References:
        https://arxiv.org/abs/1606.08415
        https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14

    The cheaper tanh approximation,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))``, is not used here.
    """
    half_x = 0.5 * x
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def get_masks(slen, lengths, causal, padding_mask=None):
    """
    Generate the hidden-states mask and the attention mask for a batch.

    Args:
        slen: sequence length (number of positions).
        lengths: ``LongTensor`` of shape ``(bs,)``. It gives the batch size and device; when
            ``padding_mask`` is None it also gives the number of real tokens in each row.
        causal: when True, the attention mask is lower-triangular, so position ``i`` can
            attend only to positions ``<= i``.
        padding_mask: optional ``(bs, slen)`` mask, used as-is for the hidden-states mask.

    Returns:
        ``(mask, attn_mask)``. ``mask`` has shape ``(bs, slen)``. ``attn_mask`` is ``mask``
        itself when ``causal`` is False, otherwise a ``(bs, slen, slen)`` boolean tensor.
    """
    bs = lengths.size(0)
    # Built unconditionally: the causal branch needs it even when padding_mask is given.
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    if padding_mask is not None:
        mask = padding_mask
    else:
        assert lengths.max().item() <= slen
        mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # sanity check
    assert mask.size() == (bs, slen)
    assert causal is False or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention used inside each XLM transformer layer.

    Every instance draws a distinct ``layer_id`` from the class-wide ``NEW_ID`` counter.
    That id is the key under which the layer stores its projected keys/values in the
    optional ``cache`` dict passed to :meth:`forward`.
    """

    # Class attribute: one counter shared by every instance created in the process.
    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, config):
        super(MultiHeadAttention, self).__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.output_attentions = config.output_attentions
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)
        # Head indices (in the original, unpruned numbering) removed so far.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """
        Remove attention heads by shrinking the q/k/v/out projections.

        Args:
            heads: iterable of head indices in the original (unpruned) numbering.
                Heads that were already pruned are ignored.
        """
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        heads = set(heads) - self.pruned_heads
        if not heads:
            # Everything requested is already gone; nothing to rebuild.
            return
        for head in heads:
            # Translate the original head index into its position in the current,
            # already-pruned layout.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input, mask, kv=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: ``(bs, qlen, dim)`` query-side hidden states.
            mask: ``(bs, klen)`` key mask, or ``(bs, qlen, klen)`` for a causal mask.
                Positions where the mask equals 0 receive zero attention weight.
            kv: optional ``(bs, klen, dim)`` states to attend over instead of ``input``.
            cache: optional dict, updated in place. ``cache[self.layer_id]`` holds this
                layer's projected ``(k, v)``; ``cache['slen']`` is the number of positions
                already cached (only read here).
            head_mask: optional multiplier applied to the attention weights
                ``(bs, n_heads, qlen, klen)`` by broadcasting; a 0 disables a head.

        Returns:
            ``(output,)``, or ``(output, weights)`` when ``output_attentions`` is set.
        """
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        # self.dim (not the input dim) is authoritative: it shrinks after head pruning.
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))                                          # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))                                          # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Self-attention: append the new positions to the cached ones.
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                else:
                    # Attention over a fixed source: reuse the cached projections as-is.
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)                                       # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)               # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))                              # (bs, n_heads, qlen, klen)

        # Softmax in float32 for stability, then cast back (fp16 compatibility).
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        outputs = (self.out_lin(context),)
        if self.output_attentions:
            outputs = outputs + (weights,)
        return outputs


class TransformerFFN(nn.Module):
    """
    Position-wise feed-forward block: ``Linear -> activation -> Linear -> dropout``.

    The activation is :func:`gelu` when ``config.gelu_activation`` is truthy, otherwise ReLU.
    Dropout probability comes from ``config.dropout``.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, config):
        super(TransformerFFN, self).__init__()
        self.dropout = config.dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)
        self.act = gelu if config.gelu_activation else F.relu

    def forward(self, input):
        """Transform the last dimension of ``input`` from ``in_dim`` to ``out_dim``."""
        x = self.lin1(input)
        x = self.act(x)
        x = self.lin2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class XLMPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = XLMConfig
    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights of ``module`` in place.

        - ``nn.Embedding``: normal(0, ``config.embed_init_std``) when that std is set. The
          ``padding_idx`` row, if any, is then reset to zero.
        - ``nn.Linear``: normal(0, ``config.init_std``) weights and zero bias when that std is set.
        - ``nn.LayerNorm``: weight 1, bias 0.
        """
        if isinstance(module, nn.Embedding):
            if self.config is not None and self.config.embed_init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
                # normal_ also overwrote the padding row; restore the zero vector that
                # nn.Embedding itself uses for padding_idx at construction time.
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        if isinstance(module, nn.Linear):
            if self.config is not None and self.config.init_std is not None:
                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.constant_(module.bias, 0.)
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# Shared introduction prepended to the XLM model docstrings through add_start_docstrings.
XLM_START_DOCSTRING = r"""    The XLM model was proposed in
    `Cross-lingual Language Model Pretraining`_
    by Guillaume Lample*, Alexis Conneau*. It's a transformer pre-trained using one of the following objectives:

        - a causal language modeling (CLM) objective (next token prediction),
        - a masked language modeling (MLM) objective (Bert-like), or
        - a Translation Language Modeling (TLM) objective (extension of Bert's MLM to multiple language inputs)

    Original code can be found `here`_.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Cross-lingual Language Model Pretraining`:
        https://arxiv.org/abs/1901.07291

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    .. _`here`:
        https://github.com/facebookresearch/XLM

    Parameters:
        config (:class:`~pytorch_transformers.XLMConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Description of the forward() inputs, appended to the XLM model docstrings.
XLM_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.

            XLM is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.XLMTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **langs**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens to be used to indicate the language of each token in the input.
            Indices are language ids which can be obtained from the language names by using two conversion mappings
            provided in the configuration of the model (only provided for multilingual models).
            More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
            the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **lengths**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Length of each sentence that can be used to avoid performing attention on padding token indices.
            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
            Indices selected in ``[0, ..., input_ids.size(-1)]``.
        **cache**:
            dictionary with ``torch.FloatTensor`` that contains pre-computed
            hidden-states (key and values in the attention blocks) as computed by the model.
            Can be used to speed up sequential decoding.
            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""


@add_start_docstrings("The bare XLM Model transformer outputting raw hidden-states without any specific head on top.",
                      XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMModel(XLMPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMModel.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output',
                  'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
                  'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                  'asm_cutoffs', 'asm_div_value']

    def __init__(self, config):  # removed from the original XLM port: dico, is_encoder, with_output
        super(XLMModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        # encoder / decoder, output layer
        self.is_encoder = config.is_encoder
        self.is_decoder = not config.is_encoder
        if self.is_decoder:
            raise NotImplementedError("Currently XLM can only be used as an encoder")
        self.causal = config.causal

        # dictionary / languages
        self.n_langs = config.n_langs
        self.use_lang_emb = config.use_lang_emb
        self.n_words = config.n_words
        self.eos_index = config.eos_index
        self.pad_index = config.pad_index

        # model parameters
        self.dim = config.emb_dim
        self.hidden_dim = self.dim * 4  # FFN inner size is fixed at 4x the model dimension
        self.n_heads = config.n_heads
        self.n_layers = config.n_layers
        self.dropout = config.dropout
        self.attention_dropout = config.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings (the sinusoidal table, if requested, is written after init_weights below)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
        if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

        # transformer layers: attention -> add & norm -> FFN -> add & norm.
        # The decoder-only encoder-attention sub-layers of the original XLM are not ported.
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
            self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))

        # Re-apply any head pruning recorded in the config; the record is reset first and
        # handed back through prune_heads for each layer that is still unpruned.
        if hasattr(config, "pruned_heads"):
            pruned_heads = config.pruned_heads.copy().items()
            config.pruned_heads = {}
            for layer, heads in pruned_heads:
                if self.attentions[int(layer)].n_heads == config.n_heads:
                    self.prune_heads({int(layer): list(map(int, heads))})

        self.init_weights()

        # Done after init_weights(): _init_weights re-draws every nn.Embedding from a normal
        # distribution (when embed_init_std is set), which would erase the sinusoidal table.
        if config.sinusoidal_embeddings:
            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)

    def _resize_token_embeddings(self, new_num_tokens):
        self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
        return self.embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.attentions[layer].prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None):  # removed: src_enc=None, src_len=None
        """ Encode ``input_ids`` (``(bs, slen)``); arguments are described in XLM_INPUTS_DOCSTRING.

        Returns ``(last_hidden_state, [hidden_states], [attentions])`` as described in the
        class docstring. When ``cache`` is given, only the positions not yet cached
        (``slen - cache['slen']``) are computed, and ``cache['slen']`` is advanced.
        """
        if lengths is None:
            lengths = (input_ids != self.pad_index).sum(dim=1).long()

        # check inputs
        bs, slen = input_ids.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)

        # position_ids
        if position_ids is None:
            position_ids = torch.arange(slen, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        else:
            assert position_ids.size() == (bs, slen)

        # langs
        if langs is not None:
            assert langs.size() == (bs, slen)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layers

        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            if token_type_ids is not None:
                # Keep aligned with input_ids, exactly like langs.
                token_type_ids = token_type_ids[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(input_ids)
        tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
        # lang_embeddings only exists under the same condition used in __init__.
        if langs is not None and self.use_lang_emb and self.n_langs > 1:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            # Token types share the word-embedding table.
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        # Zero out padded positions.
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            if self.output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
            attn = attn_outputs[0]
            if self.output_attentions:
                attentions = attentions + (attn_outputs[1],)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # FFN
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Add last hidden state
        if self.output_hidden_states:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        outputs = (tensor,)
        if self.output_hidden_states:
            outputs = outputs + (hidden_states,)
        if self.output_attentions:
            outputs = outputs + (attentions,)
        return outputs  # outputs, (hidden_states), (attentions)


class XLMPredLayer(nn.Module):
    """
    Prediction layer (cross_entropy or adaptive_softmax).

    Projects hidden states of size ``config.emb_dim`` onto ``config.n_words``
    vocabulary entries. It uses either a plain ``nn.Linear`` (``config.asm`` is
    False) or an ``nn.AdaptiveLogSoftmaxWithLoss`` (``config.asm`` is True).
    """
    def __init__(self, config):
        super(XLMPredLayer, self).__init__()
        self.asm = config.asm
        self.n_words = config.n_words
        self.pad_index = config.pad_index  # stored for reference; not used by forward()
        dim = config.emb_dim

        if config.asm is False:
            self.proj = nn.Linear(dim, config.n_words, bias=True)
        else:
            self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                in_features=dim,
                n_classes=config.n_words,
                cutoffs=config.asm_cutoffs,
                div_value=config.asm_div_value,
                head_bias=True,  # default is False
            )

    def forward(self, x, y=None):
        """ Compute the scores, and the loss when targets are given.

        Args:
            x: hidden states whose last dimension is ``emb_dim``. Any leading
                dimensions (e.g. ``(batch_size, sequence_length)``) are preserved
                in the returned scores.
            y: optional target token indices with the same leading shape as ``x``.
                When ``asm`` is disabled, targets equal to -100 are ignored by the loss.
                The adaptive softmax has no ignore index.

        Returns:
            ``(loss, scores)`` when ``y`` is given, otherwise ``(scores,)``.
            ``scores`` has shape ``x.shape[:-1] + (n_words,)``. It holds raw
            logits when ``asm`` is disabled and log-probabilities when it is enabled.
        """
        outputs = ()
        if self.asm is False:
            scores = self.proj(x)
            outputs = (scores,) + outputs
            if y is not None:
                # cross_entropy expects (N, C) scores and (N,) targets, so flatten both sides.
                loss = F.cross_entropy(scores.reshape(-1, self.n_words), y.reshape(-1),
                                       ignore_index=-100, reduction='mean')
                outputs = (loss,) + outputs
        else:
            # AdaptiveLogSoftmaxWithLoss only accepts 2D inputs and 1D targets.
            flat_x = x.reshape(-1, x.size(-1))
            scores = self.proj.log_prob(flat_x).view(*x.shape[:-1], self.n_words)
            outputs = (scores,) + outputs
            if y is not None:
                _, loss = self.proj(flat_x, y.reshape(-1))
                outputs = (loss,) + outputs

        return outputs
584

thomwolf's avatar
thomwolf committed
585

thomwolf's avatar
thomwolf committed
586
587
588
@add_start_docstrings("""The XLM Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMWithLMHeadModel(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            The labels are **not** shifted inside the model: the label at each position is
            scored against the prediction at that same position.
            Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``.
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``
            (ignoring labels is not supported when ``config.asm`` is True).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMWithLMHeadModel.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]  # Without labels, the prediction scores are the first element of the output tuple

    """
    def __init__(self, config):
        super(XLMWithLMHeadModel, self).__init__(config)
        self.transformer = XLMModel(config)
        self.pred_layer = XLMPredLayer(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the embeddings
        (the output projection reuses the input embedding weights).
        """
        self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]
        # The prediction layer returns (loss, scores) when labels are given, otherwise (scores,).
        outputs = self.pred_layer(output, labels)
        outputs = outputs + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        return outputs
649
650


thomwolf's avatar
thomwolf committed
651
652
653
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForSequenceClassification(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMForSequenceClassification.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1])  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(XLMForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLMModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]
        # Pool the sequence into one vector per example and project it to num_labels scores.
        logits = self.sequence_summary(output)

        outputs = (logits,) + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                # MSELoss needs a floating-point target of the same dtype as the logits;
                # labels may arrive as a LongTensor (as documented above), so cast them.
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs
721
722


thomwolf's avatar
thomwolf committed
723
724
725
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
class XLMForQuestionAnswering(XLMPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels whether a question has an answer or no answer (SQuAD 2.0)
        **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
        **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
        model = XLMForQuestionAnswering.from_pretrained('xlm-mlm-en-2048')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]  # With start/end positions given, the loss is the first element

    """
    def __init__(self, config):
        super(XLMForQuestionAnswering, self).__init__(config)

        self.transformer = XLMModel(config)
        self.qa_outputs = SQuADHead(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None,
                is_impossible=None, cls_index=None, p_mask=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               langs=langs,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               lengths=lengths,
                                               cache=cache,
                                               head_mask=head_mask)

        output = transformer_outputs[0]

        # NOTE(review): the exact contents of this tuple are defined by SQuADHead (modeling_utils);
        # confirm they match the Outputs section of the class docstring.
        outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
                                  cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)

        outputs = outputs + transformer_outputs[1:]  # Keep hidden states and attentions if they are here

        return outputs