# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from .modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"


BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/bart-base",
    "facebook/bart-large",
    "facebook/bart-large-mnli",
    "facebook/bart-large-cnn",
    "facebook/bart-large-xsum",
    "facebook/mbart-large-en-ro",
    # See all BART models at https://huggingface.co/models?filter=bart
]


BART_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

"""
BART_GENERATION_EXAMPLE = r"""
    Summarization example::

        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

        # see ``examples/summarization/bart/run_eval.py`` for a longer example
        model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

        # Generate Summary
        summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
        print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

"""

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
            should you provide it. Indices can be obtained using :class:`transformers.BartTokenizer`; see
            :meth:`transformers.PreTrainedTokenizer.encode` for details.
        attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices in input_ids.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, defaults to :obj:`None`):
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
            Used in the cross-attention of the decoder.
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
            If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify it to your needs.
            See diagram 1 in the paper for more information on the default strategy.
        decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up decoding.
            If ``decoder_past_key_values`` are used, the user can optionally input only the last
            ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
            :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If ``use_cache`` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
            ``decoder_past_key_values``).
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


def invert_mask(attention_mask):
    """Turns 1->0, 0->1, False->True, True->False"""
    assert attention_mask.dim() == 2
    return attention_mask.eq(0)


def _prepare_bart_decoder_inputs(
    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
    Note: this is not called during generation.
    """
    pad_token_id = config.pad_token_id
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
    bsz, tgt_len = decoder_input_ids.size()
    if decoder_padding_mask is None:
        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
    else:
        decoder_padding_mask = invert_mask(decoder_padding_mask)
    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
        dtype=causal_mask_dtype, device=decoder_input_ids.device
    )
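    # e.g. for tgt_len = 3 the causal mask built above is (0 = may attend, -inf = blocked):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]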
    return decoder_input_ids, decoder_padding_mask, causal_mask


class PretrainedBartModel(PreTrainedModel):
    config_class = BartConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, SinusoidalPositionalEmbedding):
            pass
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs


def _make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
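    # The weight is overwritten with the embedding matrix (shape: vocab_size x emb_size), so the resulting
    # layer maps emb_size -> vocab_size (a tied output projection), despite the constructor arguments above.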
    return lin_layer


# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
    if shape_1 != shape2:
        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens


def make_padding_mask(input_ids, padding_idx=1):
    """True for pad tokens"""
    padding_mask = input_ids.eq(padding_idx)
    if not padding_mask.any():
        padding_mask = None
    return padding_mask


# Helper Modules


class EncoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = SelfAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
        )
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask, output_attentions=False):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``;
                positions marked ``1`` are excluded (masked out) from attention,
                positions marked ``0`` are included.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, attn_weights = self.self_attn(
            query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, attn_weights


class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
    is a :class:`EncoderLayer`.

    Args:
        config: BartConfig
    """

    def __init__(self, config: BartConfig, embed_tokens):
        super().__init__()

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, embed_dim, self.padding_idx
            )
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
            )
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
        # mbart has one extra layer_norm
        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
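        # normalize_before and this extra layer_norm are enabled by the mbart configs; the original BART
        # checkpoints use post-layernorm and have no final encoder layer norm.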

    def forward(
        self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
    ):
        """
        Args:
            input_ids (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            attention_mask (torch.LongTensor): mask indicating which indices are padding tokens
                (``1`` for tokens that are NOT masked, ``0`` for padding).
        Returns:
            BaseModelOutput or Tuple comprised of:

                - **x** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_states** (tuple(torch.FloatTensor)): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *output_hidden_states* is True.
                - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
                  During training it might not be of length n_layers because of layer dropout.
        """
        # check attention mask and invert
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        embed_pos = self.embed_positions(input_ids)
        x = inputs_embeds + embed_pos
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = [] if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)

            if output_attentions:
                all_attentions = all_attentions + (attn,)

        if self.layer_norm:
            x = self.layer_norm(x)
        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if not return_dict:
            return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)


class DecoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = SelfAttention(
            embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.normalize_before = config.normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.encoder_attn = SelfAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attn_mask=None,
        layer_state=None,
        causal_mask=None,
        decoder_padding_mask=None,
        output_attentions=False,
    ):
        residual = x

        if layer_state is None:
            layer_state = {}
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # Self Attention

        x, self_attn_weights = self.self_attn(
            query=x,
            key=x,
            layer_state=layer_state,  # adds keys to layer state
            key_padding_mask=decoder_padding_mask,
            attn_mask=causal_mask,
            output_attentions=output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Cross attention
        residual = x
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
        if self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
        x, _ = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.encoder_attn_layer_norm(x)

        # Fully Connected
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return (
            x,
            self_attn_weights,
            layer_state,
        )  # just self_attn weights for now, following t5, layer_state = cache for decoding


class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer
    is a :class:`DecoderLayer`.
    Args:
        config: BartConfig
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
        super().__init__()
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, config.d_model, config.pad_token_id
            )
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
            )
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.decoder_layers)]
        )  # type: List[DecoderLayer]
        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        decoder_causal_mask,
        decoder_past_key_values=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        **unused,
    ):
        """
        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            input_ids (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention
            encoder_padding_mask: for ignoring pad tokens
            decoder_past_key_values (list of dicts or None): per-layer cached attention
                states, used to speed up decoding during generation

        Returns:
            BaseModelOutputWithPast or tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - the cache
                - hidden states
                - attentions
        """
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
                FutureWarning,
            )
            decoder_past_key_values = unused.pop("decoder_cached_states")

        # check attention mask and invert
        if encoder_padding_mask is not None:
            encoder_padding_mask = invert_mask(encoder_padding_mask)

        # embed positions
        positions = self.embed_positions(input_ids, use_cache=use_cache)

        if use_cache:
            input_ids = input_ids[:, -1:]
            positions = positions[:, -1:]  # happens after we embed them
            # assert input_ids.ne(self.padding_idx).any()

        x = self.embed_tokens(input_ids) * self.embed_scale
        x += positions
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Convert to the layout used inside the layers: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = []
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (x,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None

            x, layer_self_attn, layer_past = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
                decoder_padding_mask=decoder_padding_mask,
                layer_state=layer_state,
                causal_mask=decoder_causal_mask,
                output_attentions=output_attentions,
            )

            if use_cache:
                next_decoder_cache.append(layer_past.copy())

            if self.layer_norm and (idx == len(self.layers) - 1):  # last layer of mbart
                x = self.layer_norm(x)
            if output_attentions:
                all_self_attns += (layer_self_attn,)

        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        if output_hidden_states:
            all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        if use_cache:
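            # Cache layout consumed by prepare_inputs_for_generation / _reorder_cache:
            # ((encoder_hidden_states, encoder_padding_mask), [per-layer attention state dicts])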
            next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
        else:
            next_cache = None

        if not return_dict:
            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
        )


def _reorder_buffer(attn_cache, new_order):
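    """Reorder cached key/value tensors along the batch dimension to match the new beam order."""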
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache


class SelfAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        encoder_decoder_attention=False,  # otherwise self_attention
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.encoder_decoder_attention = encoder_decoder_attention
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, seq_len, bsz):
        return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
        attn_mask: Optional[Tensor] = None,
        output_attentions=False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time(SeqLen) x Batch x Channel"""
        static_kv: bool = self.encoder_decoder_attention
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        # reached for encoder-decoder (cross) attention because of static_kv
        if layer_state is not None:  # reuse k,v and encoder_padding_mask
            saved_state = layer_state.get(self.cache_key, {})
            if "prev_key" in saved_state and static_kv:
                # previous time steps are cached - no need to recompute key and value if they are static
                key = None
        else:
            saved_state = None
            layer_state = {}

        q = self.q_proj(query) * self.scaling
        if static_kv:
            if key is None:
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self._shape(q, tgt_len, bsz)
        if k is not None:
            k = self._shape(k, -1, bsz)
        if v is not None:
            v = self._shape(v, -1, bsz)

        if saved_state is not None:
            k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)

        # Update cache
        layer_state[self.cache_key] = {
            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
        }

        assert k is not None
        src_len = k.size(1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)

        if key_padding_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)

        assert v is not None
        attn_output = torch.bmm(attn_probs, v)
        assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        if output_attentions:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        else:
            attn_weights = None
        return attn_output, attn_weights

    def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        assert k is not None and v is not None
        prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
        if prev_key_padding_mask is not None:
            if static_kv:
                new_key_padding_mask = prev_key_padding_mask
            else:
                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
        else:
            new_key_padding_mask = key_padding_mask
        return k, v, new_key_padding_mask


class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    # This can trivially be shared with RobertaClassificationHead

    def __init__(
        self, input_dim, inner_dim, num_classes, pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        self.offset = offset
        assert padding_idx is not None
        num_embeddings += offset
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

    def forward(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
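            # During incremental generation only the newest token needs a position embedding, so a single
            # position (seq_len - 1) is returned; the caller slices input_ids to its last token afterwards.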
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions + self.offset)


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
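    # Use apex's FusedLayerNorm when a GPU is available and apex is installed;
    # otherwise fall back to torch.nn.LayerNorm.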
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a input_ids with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


# Public API
def _get_shape(t):
    return getattr(t, "shape", None)


@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
    def __init__(self, config: BartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        self.init_weights()

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        decoder_past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):

        if decoder_input_ids is None:
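            # decoder inputs will be built by shifting input_ids (the training / teacher-forcing path),
            # so incremental caching does not apply here.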
            use_cache = False

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # make masks if user doesn't supply
        if not use_cache:
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_past_key_values=decoder_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)  # make it on the fly


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class BartForConditionalGeneration(PretrainedBartModel):
    base_model_prefix = "model"
    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens = self.model.shared.num_embeddings
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self.model.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **unused,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels in ``[0, ..., config.vocab_size]``.

    Returns:

    Conditional generation example::

            # Mask filling only works for bart-large
            from transformers import BartTokenizer, BartForConditionalGeneration
            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
            TXT = "My friends are <mask> but they eat too many carbs."

            model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
            input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
            logits = model(input_ids).logits

            masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
            probs = logits[0, masked_index].softmax(dim=0)
            values, predictions = probs.topk(5)
            tokenizer.decode(predictions).split()
            # ['good', 'great', 'all', 'really', 'very']
        """
        if "lm_labels" in unused:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = unused.pop("lm_labels")
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
                FutureWarning,
            )
            decoder_past_key_values = unused.pop("decoder_cached_states")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_past_key_values=decoder_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
Sylvain Gugger's avatar
Sylvain Gugger committed
1044
1045
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
Sam Shleifer's avatar
Sam Shleifer committed
1046

1047
        if not return_dict:
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            decoder_past_key_values=outputs.decoder_past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
Sam Shleifer's avatar
Sam Shleifer committed
1061

    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        encoder_outputs, decoder_past_key_values = past
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "decoder_past_key_values": decoder_past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

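    # During generation, BART forces the first generated token to be BOS and the token at the final
    # step to be EOS by masking every other logit to -inf (see ``_force_token_ids_generation`` below).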
    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        if cur_len == 1:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_ids) -> None:
        """Force one of ``token_ids`` to be generated by setting the logits of all other tokens to -inf (probability 0)."""
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        all_but_token_ids_mask = torch.tensor(
            [x for x in range(self.config.vocab_size) if x not in token_ids],
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        scores[:, all_but_token_ids_mask] = -float("inf")

    @staticmethod
    def _reorder_cache(past, beam_idx):
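        # ``beam_idx`` holds, for each beam kept at this step of beam search, the index of the row
        # (in the flattened batch x beam dimension) it was expanded from; every cached tensor is
        # re-indexed along that dimension so the cache follows the surviving beams.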
        ((enc_out, enc_mask), decoder_past_key_values) = past
        reordered_past = []
        for layer_past in decoder_past_key_values:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)

        new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
        new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)

        past = ((new_enc_out, new_enc_mask), reordered_past)
        return past

    def get_encoder(self):
        return self.model.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly


@add_start_docstrings(
    """Bart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
    BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels, config.classif_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
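
        Example (a minimal sketch; the premise/hypothesis pair is illustrative and assumes the
        ``facebook/bart-large-mnli`` checkpoint, whose label order is defined by its config)::

            from transformers import BartTokenizer, BartForSequenceClassification

            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
            model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

            premise = "My friends are cool."
            hypothesis = "My friends are nice."
            input_ids = tokenizer(premise, hypothesis, return_tensors='pt')['input_ids']

            logits = model(input_ids, return_dict=True).logits  # shape (batch_size, num_labels)
            probs = logits.softmax(dim=1)  # one probability distribution per example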
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        x = outputs[0]  # last hidden state
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            decoder_past_key_values=outputs.decoder_past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


@add_start_docstrings(
    """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BART_START_DOCSTRING,
)
class BartForQuestionAnswering(PretrainedBartModel):
    def __init__(self, config):
        super().__init__(config)

        config.num_labels = 2
        self.num_labels = config.num_labels

        self.model = BartModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.model._init_weights(self.qa_outputs)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=Seq2SeqQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        start_positions=None,
        end_positions=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
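
        Example (a minimal sketch for illustration; the ``facebook/bart-large`` checkpoint has a randomly
        initialized QA head, so the extracted span is only meaningful after fine-tuning)::

            from transformers import BartTokenizer, BartForQuestionAnswering

            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
            model = BartForQuestionAnswering.from_pretrained('facebook/bart-large')

            question = "Who released BART?"  # illustrative inputs
            context = "BART was released by Facebook AI Research."
            inputs = tokenizer(question, context, return_tensors='pt')

            outputs = model(**inputs, return_dict=True)
            start = outputs.start_logits.argmax(dim=-1).item()  # most likely start index
            end = outputs.end_logits.argmax(dim=-1).item()  # most likely end index
            answer = tokenizer.decode(inputs['input_ids'][0, start : end + 1])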
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if start_positions is not None and end_positions is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split may add an extra dimension; squeeze it if present
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits,) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            decoder_past_key_values=outputs.decoder_past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


class SinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        super().__init__(num_positions, embedding_dim)
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
            The cos features are in the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # This line breaks for odd dim (guarded against in __init__)
        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        out.requires_grad = False
        return out

    @torch.no_grad()
    def forward(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
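

# A quick shape check for ``SinusoidalPositionalEmbedding`` (illustrative values; as described in
# ``_init_weight``, the sin features fill the first half of each vector and the cos features the second half):
#
#     emb = SinusoidalPositionalEmbedding(num_positions=1024, embedding_dim=16)
#     positions = emb(torch.zeros(2, 8, dtype=torch.long))  # -> tensor of shape (8, 16), one row per position
#     sin_part, cos_part = positions[:, :8], positions[:, 8:]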