modeling_bart.py 44.6 KB
Newer Older
Sam Shleifer's avatar
Sam Shleifer committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
Sam Shleifer's avatar
Sam Shleifer committed
17
import math
Sam Shleifer's avatar
Sam Shleifer committed
18
19
20
21
22
23
24
import random
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

25
from .activations import ACT2FN
Sam Shleifer's avatar
Sam Shleifer committed
26
27
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
patrickvonplaten's avatar
patrickvonplaten committed
28
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
Sam Shleifer's avatar
Sam Shleifer committed
29
30
31
32
33
34
35
36


logger = logging.getLogger(__name__)


# Shortcut name -> URL of the pretrained PyTorch weights for each published BART checkpoint.
BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin",
    "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin",
    "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/pytorch_model.bin",
    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/pytorch_model.bin",
    "mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/pytorch_model.bin",
}

BART_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

52
53
54
55
56
57
58
59
60
61
62
"""
BART_GENERATION_EXAMPLE = r"""
    Examples::

        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
        # see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
        model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
        tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
        # Generate Summary
63
        summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
64
65
        print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

Sam Shleifer's avatar
Sam Shleifer committed
66
67
68
69
70
71
72
73
74
75
76
77
"""

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
               Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
            Padding will be ignored by default should you provide it.
            Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
        attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices in input_ids.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
78
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Patrick von Platen's avatar
Patrick von Platen committed
79
80
81
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
            Used in the cross-attention of the decoder.
Sam Shleifer's avatar
Sam Shleifer committed
82
83
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
84
85
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
Sam Shleifer's avatar
Sam Shleifer committed
86
87
88
            If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
            See diagram 1 in the paper for more info on the default strategy
"""
89
90
91
92
93


def invert_mask(attention_mask):
    """Turn a 2-D ``1 = attend / 0 = pad`` mask into a boolean mask that is ``True`` at padding.

    Raises ``AssertionError`` when the mask is not two-dimensional.
    """
    rank = attention_mask.dim()
    assert rank == 2
    return torch.eq(attention_mask, 0)
Sam Shleifer's avatar
Sam Shleifer committed
94
95
96


def _prepare_bart_decoder_inputs(
    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
    Note: this is not called during generation

    Args:
        config: configuration object; only ``config.pad_token_id`` is read.
        input_ids: encoder token ids. Used only to build ``decoder_input_ids`` (via
            ``shift_tokens_right``) when those are not supplied.
        decoder_input_ids: optional ``(bsz, tgt_len)`` decoder token ids.
        decoder_padding_mask: optional 2-D user mask where ``0`` marks padding. It is inverted with
            ``invert_mask`` so the returned mask is ``True`` at padding. When omitted, it is derived
            from ``decoder_input_ids`` with ``make_padding_mask`` (which yields ``None`` when there is
            no padding at all).
        causal_mask_dtype: dtype of the returned causal mask.

    Returns:
        tuple: ``(decoder_input_ids, decoder_padding_mask, causal_mask)``. ``causal_mask`` is a
        ``(tgt_len, tgt_len)`` tensor on ``decoder_input_ids.device``. It is built with
        ``torch.triu(..., 1)``, so entries on and below the diagonal are zero. Entries strictly
        above the diagonal keep the value produced by ``fill_with_neg_inf`` (presumably ``-inf``).
    """
    pad_token_id = config.pad_token_id
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
    bsz, tgt_len = decoder_input_ids.size()
    if decoder_padding_mask is None:
        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
    else:
        decoder_padding_mask = invert_mask(decoder_padding_mask)
    # Upper-triangular (k=1) mask: position i may only attend to positions <= i.
    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
        dtype=causal_mask_dtype, device=decoder_input_ids.device
    )
    return decoder_input_ids, decoder_padding_mask, causal_mask


class PretrainedBartModel(PreTrainedModel):
    """Base class hooking BART into :class:`PreTrainedModel`: config class, checkpoint map,
    weight initialisation and dummy inputs."""

    config_class = BartConfig
    base_model_prefix = "model"
    pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` and ``nn.Embedding`` weights from N(0, ``config.init_std``).

        Linear biases are zeroed, and so is the embedding row at ``padding_idx`` (when set).
        """
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Two-sequence batch on ``self.device``; the second row ends with ``config.pad_token_id``.

        Returns:
            dict: ``input_ids`` and a matching boolean ``attention_mask`` (``True`` where not padding).
        """
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs


def _make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
147
    lin_layer.weight.data = emb.weight.data
Sam Shleifer's avatar
Sam Shleifer committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    return lin_layer


# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
    if shape_1 != shape2:
        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


def shift_tokens_right(input_ids, pad_token_id):
    """Shift ``input_ids`` one step to the right, wrapping each row's last non-pad token to the front.

    The last non-pad token is usually ``<eos>``. The position of that token is taken as
    (number of non-pad tokens - 1), so rows are assumed to be right-padded. ``input_ids``
    itself is not modified.
    """
    shifted = input_ids.clone()
    n_real = input_ids.ne(pad_token_id).sum(dim=1)
    last_real_pos = (n_real - 1).unsqueeze(-1)
    wrapped_tokens = torch.gather(input_ids, 1, last_real_pos).squeeze()
    shifted[:, 0] = wrapped_tokens
    shifted[:, 1:] = input_ids[:, :-1]
    return shifted


def make_padding_mask(input_ids, padding_idx=1):
    """Return a boolean mask that is True at pad tokens, or ``None`` when nothing is padded."""
    is_pad = input_ids.eq(padding_idx)
    return is_pad if is_pad.any() else None


# Helper Modules


class EncoderLayer(nn.Module):
    """One BART encoder block: self-attention followed by a two-layer feed-forward network.

    Each sub-layer is followed by dropout and a residual connection. When
    ``config.normalize_before`` is set, layer norm is applied to the sub-layer input (pre-norm);
    otherwise it is applied after the residual addition (post-norm).
    """

    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.output_attentions = config.output_attentions
        self.self_attn = SelfAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
        )
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``
                (excluded from attention); ``0`` means the position is included.

        Returns:
            tuple: the encoded output of shape `(seq_len, batch, embed_dim)` and the
            self-attention weights returned by ``self.self_attn``. Weights are requested only
            when ``self.output_attentions`` is true.
        """
        # --- self-attention sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, attn_weights = self.self_attn(
            query=x, key=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- feed-forward sub-layer ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, attn_weights


class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
    is a :class:`EncoderLayer`.

    Args:
        config: BartConfig
        embed_tokens (torch.nn.Embedding): token embedding shared with the rest of the model
    """

    def __init__(self, config: BartConfig, embed_tokens):
        super().__init__()

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = embed_tokens

        self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = LayerNorm(embed_dim)
        # mbart has one extra layer_norm
        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None

    def forward(
        self, input_ids, attention_mask=None,
    ):
        """
        Args:
            input_ids (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            attention_mask (torch.LongTensor): 2-D mask with ``1`` for tokens to attend to and
                ``0`` for padding. It is inverted internally with ``invert_mask``.
        Returns:
            Tuple comprised of:
                - **x** (Tensor): the last encoder layer's output of
                  shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(batch, src_len, embed_dim)`.
                  Only populated if *self.output_hidden_states* is True: the input to every
                  layer plus the final output, i.e. n_layers + 1 entries.
                - **all_attentions** (List[Tensor]): Attention weights for each layer.
                  Only populated if *self.output_attentions* is True. Has one entry per layer;
                  the entry is ``None`` for layers skipped by LayerDrop during training.
        """
        # check attention mask and invert
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        embed_pos = self.embed_positions(input_ids)
        x = inputs_embeds + embed_pos
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states, all_attentions = [], []
        for encoder_layer in self.layers:
            if self.output_hidden_states:
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask)

            if self.output_attentions:
                all_attentions.append(attn)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
        if self.output_hidden_states:
            encoder_states.append(x)

        # T x B x C -> B x T x C
        encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
        x = x.transpose(0, 1)

        return x, encoder_states, all_attentions


class DecoderLayer(nn.Module):
    """One BART decoder block.

    The sub-layers are causal self-attention, encoder-decoder (cross) attention, and a two-layer
    feed-forward network. Each is followed by dropout and a residual connection. With
    ``config.normalize_before`` the layer norm is applied to each sub-layer input (pre-norm);
    otherwise it is applied after the residual addition (post-norm).
    """

    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.output_attentions = config.output_attentions
        self.self_attn = SelfAttention(
            embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.normalize_before = config.normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.encoder_attn = SelfAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attn_mask=None,
        layer_state=None,
        causal_mask=None,
        decoder_padding_mask=None,
    ):
        """Run the decoder block.

        Args:
            x (Tensor): decoder input. Presumably time-first `(tgt_len, batch, embed_dim)`, matching
                the "Time x Batch x Channel" input of :class:`SelfAttention` (TODO confirm).
            encoder_hidden_states (Tensor): keys/values for cross-attention, in the same layout.
            encoder_attn_mask (Tensor, optional): key padding mask for the encoder positions.
            layer_state (dict, optional): per-layer attention cache. The two attention modules
                write into it under distinct ``cache_key`` values. A fresh dict is used when ``None``.
            causal_mask (Tensor, optional): additive attention mask for self-attention.
            decoder_padding_mask (Tensor, optional): key padding mask for the decoder positions.

        Returns:
            tuple: ``(x, self_attn_weights, layer_state)``. Only the self-attention weights are
            returned (cross-attention weights are discarded). ``layer_state`` is the possibly
            updated cache used for decoding.
        """
        residual = x

        if layer_state is None:
            layer_state = {}
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # Self Attention
        x, self_attn_weights = self.self_attn(
            query=x,
            key=x,
            layer_state=layer_state,  # adds keys to layer state
            key_padding_mask=decoder_padding_mask,
            attn_mask=causal_mask,
            need_weights=self.output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Cross attention
        residual = x
        # Both attentions share layer_state, so their cache entries must not collide.
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
        if self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
        x, _ = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.encoder_attn_layer_norm(x)

        # Fully Connected
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return (
            x,
            self_attn_weights,
            layer_state,
        )  # just self_attn weights for now, following t5, layer_state = cache for decoding


class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer
    is a :class:`DecoderLayer`.
    Args:
        config: BartConfig
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = embed_tokens
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_position_embeddings, config.d_model, self.padding_idx,
        )
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.decoder_layers)]
        )  # type: List[DecoderLayer]
        self.layernorm_embedding = LayerNorm(config.d_model)
        # Extra norm applied after the last layer (used by mbart).
        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        decoder_causal_mask,
        decoder_cached_states=None,
        use_cache=False,
        **unused
    ):
        """
        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            input_ids (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention, batch-first `(batch, src_len, embed_dim)`
            encoder_padding_mask: 2-D mask (``0`` = padding) for ignoring pad tokens;
                inverted internally with ``invert_mask``
            decoder_padding_mask: key padding mask forwarded to each layer's self-attention
            decoder_causal_mask: attention mask forwarded to each layer's self-attention
            decoder_cached_states (list or None): per-layer cache dicts used for storing state
                during generation
            use_cache (bool): if True, only the last position of ``input_ids`` is decoded and a
                new cache is returned

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                  (``tgt_len`` is 1 when ``use_cache`` is True)
                - the next cache, ``((encoder_hidden_states, encoder_padding_mask), per-layer states)``,
                  or ``None`` when ``use_cache`` is False
                - hidden states: the input to each executed layer, batch-first (empty unless
                  ``output_hidden_states``)
                - attentions: self-attention weights of each executed layer (empty unless
                  ``output_attentions``)
        """
        # check attention mask and invert
        if encoder_padding_mask is not None:
            encoder_padding_mask = invert_mask(encoder_padding_mask)

        # embed positions
        positions = self.embed_positions(input_ids, use_cache=use_cache)

        if use_cache:
            input_ids = input_ids[:, -1:]
            positions = positions[:, -1:]  # happens after we embed them
            assert input_ids.ne(self.padding_idx).any()

        x = self.embed_tokens(input_ids) * self.embed_scale
        x += positions
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Convert to the time-first layout the layers expect: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        # decoder layers
        all_hidden_states = ()
        all_self_attns = ()
        next_decoder_cache = []
        for idx, decoder_layer in enumerate(self.layers):
            if self.output_hidden_states:
                all_hidden_states += (x,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                # NOTE(review): a skipped layer contributes no cache entry, so with use_cache=True
                # during training next_decoder_cache would no longer line up with layer indices.
                continue

            layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None

            x, layer_self_attn, layer_past = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
                decoder_padding_mask=decoder_padding_mask,
                layer_state=layer_state,
                causal_mask=decoder_causal_mask,
            )

            if use_cache:
                next_decoder_cache.append(layer_past.copy())

            if self.layer_norm is not None and (idx == len(self.layers) - 1):  # last layer of mbart
                x = self.layer_norm(x)
            if self.output_attentions:
                all_self_attns += (layer_self_attn,)

        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        if use_cache:
            # NOTE(review): the cached encoder_padding_mask is the already-inverted (True = pad)
            # mask; confirm that callers re-using it do not invert it a second time.
            next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
        else:
            next_cache = None
        return x, next_cache, all_hidden_states, list(all_self_attns)


527
528
def _reorder_buffer(attn_cache, new_order):
    for k, input_buffer_k in attn_cache.items():
Sam Shleifer's avatar
Sam Shleifer committed
529
        if input_buffer_k is not None:
530
531
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache
Sam Shleifer's avatar
Sam Shleifer committed
532
533
534


class SelfAttention(nn.Module):
535
    """Multi-headed attention from 'Attention Is All You Need' paper"""
Sam Shleifer's avatar
Sam Shleifer committed
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        encoder_decoder_attention=False,  # otherwise self_attention
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.encoder_decoder_attention = encoder_decoder_attention
554
555
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
Sam Shleifer's avatar
Sam Shleifer committed
556
557
558
559
560
561
562
563
564
565
566
567
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, dim_0, bsz):
        return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
568
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
Sam Shleifer's avatar
Sam Shleifer committed
569
        attn_mask: Optional[Tensor] = None,
570
        need_weights=False,
Sam Shleifer's avatar
Sam Shleifer committed
571
    ) -> Tuple[Tensor, Optional[Tensor]]:
572
573
        """Input shape: Time(SeqLen) x Batch x Channel"""
        static_kv = self.encoder_decoder_attention  # type: bool
Sam Shleifer's avatar
Sam Shleifer committed
574
575
576
577
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        # get here for encoder decoder cause of static_kv
578
        if layer_state is not None:  # reuse k,v and encoder_padding_mask
Sam Shleifer's avatar
Sam Shleifer committed
579
            saved_state = layer_state.get(self.cache_key, {})
Sam Shleifer's avatar
Sam Shleifer committed
580
581
582
            if "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute key and value if they are static
                if static_kv:
583
                    key = None
Sam Shleifer's avatar
Sam Shleifer committed
584
585
        else:
            saved_state = None
Sam Shleifer's avatar
Sam Shleifer committed
586
            layer_state = {}
Sam Shleifer's avatar
Sam Shleifer committed
587
588

        q = self.q_proj(query) * self.scaling
589
        if static_kv:
Sam Shleifer's avatar
Sam Shleifer committed
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
            if key is None:
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self._shape(q, tgt_len, bsz)
        if k is not None:
            k = self._shape(k, -1, bsz)
        if v is not None:
            v = self._shape(v, -1, bsz)

        if saved_state is not None:
Sam Shleifer's avatar
Sam Shleifer committed
606
607
608
609
610
611
612
613
614
            k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)

        # Update cache
        layer_state[self.cache_key] = {
            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
        }

Sam Shleifer's avatar
Sam Shleifer committed
615
616
617
618
619
620
621
622
623
624
625
626
        assert k is not None
        src_len = k.size(1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
Sam Shleifer's avatar
Sam Shleifer committed
627
        assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
Sam Shleifer's avatar
Sam Shleifer committed
628
629
630

        if key_padding_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
631
            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
Sam Shleifer's avatar
Sam Shleifer committed
632
633
            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
sshleifer's avatar
sshleifer committed
634
635
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
sshleifer's avatar
sshleifer committed
636

Sam Shleifer's avatar
Sam Shleifer committed
637
638
639
640
641
        assert v is not None
        attn_output = torch.bmm(attn_probs, v)
        assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
642
643
644
645
        if need_weights:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        else:
            attn_weights = None
Sam Shleifer's avatar
Sam Shleifer committed
646
647
        return attn_output, attn_weights

Sam Shleifer's avatar
Sam Shleifer committed
648
    def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
Sam Shleifer's avatar
Sam Shleifer committed
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        assert k is not None and v is not None
        prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None)  # type: Optional[Tensor]
        key_padding_mask = self._cat_prev_key_padding_mask(
            key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
        )
Sam Shleifer's avatar
Sam Shleifer committed
673
        return k, v, key_padding_mask
Sam Shleifer's avatar
Sam Shleifer committed
674
675
676
677
678
679
680
681
682
683

    @staticmethod
    def _cat_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
684
685
686
687
688
689
        if prev_key_padding_mask is not None:
            if static_kv:
                new_key_padding_mask = prev_key_padding_mask
            else:
                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)

Sam Shleifer's avatar
Sam Shleifer committed
690
        elif key_padding_mask is not None:
691
692
693
694
695
696
697
            filler = torch.zeros(
                batch_size,
                src_len - key_padding_mask.size(1),
                dtype=key_padding_mask.dtype,
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
Sam Shleifer's avatar
Sam Shleifer committed
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask


class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    # This can trivially be shared with RobertaClassificationHead

    def __init__(
        self, input_dim, inner_dim, num_classes, pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(
        self, num_embeddings: int, embedding_dim: int, padding_idx: int,
    ):
        # if padding_idx is specified then offset the embedding ids by
        # this index and adjust num_embeddings appropriately
        assert padding_idx is not None
        num_embeddings += padding_idx + 1  # WHY?
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

742
    def forward(self, input, use_cache=False):
Sam Shleifer's avatar
Sam Shleifer committed
743
        """Input is expected to be of size [bsz x seqlen]."""
744
        if use_cache:  # the position is our current step in the decoded sequence
Sam Shleifer's avatar
Sam Shleifer committed
745
746
747
748
            pos = int(self.padding_idx + input.size(1))
            positions = input.data.new(1, 1).fill_(pos)
        else:
            positions = create_position_ids_from_input_ids(input, self.padding_idx)
Sam Shleifer's avatar
Sam Shleifer committed
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
        return super().forward(positions)


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a input_ids with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


def _filter_out_falsey_values(tup) -> Tuple:
    """Remove entries that are None or [] from an iterable."""
    return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)


# Public API
774
775
def _get_shape(t):
    return getattr(t, "shape", None)
Sam Shleifer's avatar
Sam Shleifer committed
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803


@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
    def __init__(self, config: BartConfig):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        self.init_weights()

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs=None,  # type: Tuple
        decoder_attention_mask=None,
        decoder_cached_states=None,
804
        use_cache=False,
Sam Shleifer's avatar
Sam Shleifer committed
805
806
807
    ):

        # make masks if user doesn't supply
808
        if not use_cache:
809
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
810
811
812
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
813
814
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
Sam Shleifer's avatar
Sam Shleifer committed
815
            )
816
817
818
        else:
            decoder_padding_mask, causal_mask = None, None

Sam Shleifer's avatar
Sam Shleifer committed
819
820
        assert decoder_input_ids is not None
        if encoder_outputs is None:
821
            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
Sam Shleifer's avatar
Sam Shleifer committed
822
        assert isinstance(encoder_outputs, tuple)
Sam Shleifer's avatar
Sam Shleifer committed
823
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
824
        decoder_outputs = self.decoder(
Sam Shleifer's avatar
Sam Shleifer committed
825
826
827
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
828
829
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
Sam Shleifer's avatar
Sam Shleifer committed
830
            decoder_cached_states=decoder_cached_states,
831
            use_cache=use_cache,
Sam Shleifer's avatar
Sam Shleifer committed
832
833
834
835
836
837
838
839
840
841
842
843
        )
        # Attention and hidden_states will be [] or None if they aren't needed
        decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
        assert isinstance(decoder_outputs[0], torch.Tensor)
        encoder_outputs = _filter_out_falsey_values(encoder_outputs)  # type: tuple
        return decoder_outputs + encoder_outputs

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
844
845
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared
Sam Shleifer's avatar
Sam Shleifer committed
846
847

    def get_output_embeddings(self):
Sam Shleifer's avatar
Sam Shleifer committed
848
        return _make_linear_from_emb(self.shared)  # make it on the fly
Sam Shleifer's avatar
Sam Shleifer committed
849
850
851


@add_start_docstrings(
852
853
    "The BART Model with a language modeling head. Can be used for summarization.",
    BART_START_DOCSTRING + BART_GENERATION_EXAMPLE,
Sam Shleifer's avatar
Sam Shleifer committed
854
)
855
class BartForConditionalGeneration(PretrainedBartModel):
Sam Shleifer's avatar
Sam Shleifer committed
856
857
858
859
    base_model_prefix = "model"

    def __init__(self, config: BartConfig):
        super().__init__(config)
Sam Shleifer's avatar
Sam Shleifer committed
860
861
862
        base_model = BartModel(config)
        self.model = base_model

Sam Shleifer's avatar
Sam Shleifer committed
863
864
865
866
867
868
869
870
871
872
    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_cached_states=None,
        lm_labels=None,
873
        use_cache=False,
Sam Shleifer's avatar
Sam Shleifer committed
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
        **unused
    ):
        r"""
        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
            with labels
            in ``[0, ..., config.vocab_size]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
        masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

904
905
            # Mask filling only works for bart-large
            from transformers import BartTokenizer, BartForConditionalGeneration
Sam Shleifer's avatar
Sam Shleifer committed
906
            tokenizer = BartTokenizer.from_pretrained('bart-large')
907
908
909
910
911
912
913
914
915
            TXT = "My friends are <mask> but they eat too many carbs."
            model = BartForConditionalGeneration.from_pretrained('bart-large')
            input_ids = tokenizer.batch_encode_plus([TXT], return_tensors='pt')['input_ids']
            logits = model(input_ids)[0]
            masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
            probs = logits[0, masked_index].softmax(dim=0)
            values, predictions = probs.topk(5)
            tokenizer.decode(predictions).split()
            # ['good', 'great', 'all', 'really', 'very']
Sam Shleifer's avatar
Sam Shleifer committed
916
        """
917
        outputs = self.model(
Sam Shleifer's avatar
Sam Shleifer committed
918
919
920
921
922
923
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
924
            use_cache=use_cache,
Sam Shleifer's avatar
Sam Shleifer committed
925
        )
926
        lm_logits = F.linear(outputs[0], self.model.shared.weight)
Sam Shleifer's avatar
Sam Shleifer committed
927
928
929
930
931
932
933
934
935
        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
        if lm_labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in lm_labels?
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs

936
937
938
939
940
941
    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step, decoder_cached_states are empty
        if not past[1]:
            encoder_outputs, decoder_cached_states = past, None
Patrick von Platen's avatar
Patrick von Platen committed
942
943
944
        else:
            encoder_outputs, decoder_cached_states = past
        return {
945
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
Patrick von Platen's avatar
Patrick von Platen committed
946
947
948
            "encoder_outputs": encoder_outputs,
            "decoder_cached_states": decoder_cached_states,
            "decoder_input_ids": decoder_input_ids,
949
            "attention_mask": attention_mask,
950
            "use_cache": True,  # change this to avoid caching (presumably for debugging)
Sam Shleifer's avatar
Sam Shleifer committed
951
952
        }

patrickvonplaten's avatar
patrickvonplaten committed
953
954
955
    def prepare_scores_for_generation(self, scores, cur_len, max_length):
        if cur_len == 1:
            self._force_token_ids_generation(scores, self.config.bos_token_id)
956
957
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(scores, self.config.eos_token_id)
patrickvonplaten's avatar
patrickvonplaten committed
958
959
        return scores

Sam Shleifer's avatar
Sam Shleifer committed
960
961
962
963
964
965
966
    @staticmethod
    def _reorder_cache(past, beam_idx):
        ((enc_out, enc_mask), decoder_cached_states) = past
        reordered_past = []
        for layer_past in decoder_cached_states:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
967
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
Sam Shleifer's avatar
Sam Shleifer committed
968
969
            }
            reordered_past.append(layer_past_new)
970
971

        new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
Sam Shleifer's avatar
Sam Shleifer committed
972
973
974
975
        new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)

        past = ((new_enc_out, new_enc_mask), reordered_past)
        return past
Sam Shleifer's avatar
Sam Shleifer committed
976

977
978
979
    def get_encoder(self):
        return self.model.encoder

Sam Shleifer's avatar
Sam Shleifer committed
980
    def get_output_embeddings(self):
981
        return _make_linear_from_emb(self.model.shared)  # make it on the fly
Sam Shleifer's avatar
Sam Shleifer committed
982

983
984
985
986
    def _do_output_past(self, *args, **kwargs):
        """ We should always use the cache in generate."""
        return True

Sam Shleifer's avatar
Sam Shleifer committed
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020

@add_start_docstrings(
    """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
    BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model, config.d_model, config.num_labels, config.classif_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Patrick von Platen's avatar
Patrick von Platen committed
1021
                Classification loss (cross entropy)
Sam Shleifer's avatar
Sam Shleifer committed
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
            logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attentions weights after the attention softmax, used to compute the weighted average in the
                self-attention
                heads.

    Examples::

        from transformers import BartTokenizer, BartForSequenceClassification
        import torch

        tokenizer = BartTokenizer.from_pretrained('bart-large')
        model = BartForSequenceClassification.from_pretrained('bart-large')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute",
        add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

        """
1048
        outputs = self.model(
Sam Shleifer's avatar
Sam Shleifer committed
1049
1050
1051
1052
1053
1054
1055
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
        )
        x = outputs[0]  # last hidden state
1056
        eos_mask = input_ids.eq(self.config.eos_token_id)
Sam Shleifer's avatar
Sam Shleifer committed
1057
1058
1059
1060
1061
1062
1063
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
        logits = self.classification_head(sentence_representation)
        # Prepend logits
        outputs = (logits,) + outputs[1:]  # Add hidden states and attention if they are here
        if labels is not None:  # prepend loss to output,
1064
            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
Sam Shleifer's avatar
Sam Shleifer committed
1065
1066
1067
            outputs = (loss,) + outputs

        return outputs