# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import math
import random
import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from .modeling_utils import PreTrainedModel
from .utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"


BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/bart-base",
    "facebook/bart-large",
    "facebook/bart-large-mnli",
    "facebook/bart-large-cnn",
    "facebook/bart-large-xsum",
    "facebook/mbart-large-en-ro",
]
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart


BART_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.

"""
BART_GENERATION_EXAMPLE = r"""
    Summarization example::

        from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

        # see ``examples/summarization/bart/run_eval.py`` for a longer example
        model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

        # Generate Summary
        summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
        print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

"""

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them.
            Padding will be ignored by default should you provide it.
            Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`.
        attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices in input_ids.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
            Used in the cross-attention of the decoder.
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
            If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
            See diagram 1 in the paper for more info on the default strategy
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up decoding.
            If ``past_key_values`` are used, the user can optionally input only the last
            ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
            :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
            ``past_key_values``).
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


def invert_mask(attention_mask):
    """Turns 1->0, 0->1, False->True, True-> False"""
    assert attention_mask.dim() == 2
    return attention_mask.eq(0)


def _prepare_bart_decoder_inputs(
    config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
    """Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if
    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
    Note: this is not called during generation
    """
    pad_token_id = config.pad_token_id
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
    bsz, tgt_len = decoder_input_ids.size()
    if decoder_padding_mask is None:
        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
    else:
        decoder_padding_mask = invert_mask(decoder_padding_mask)
    if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
        # never mask leading token, even if it is pad
        decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
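    # Editorial illustration (not part of the original port): for tgt_len == 4, the causal mask built
    # below is
    #     [[0., -inf, -inf, -inf],
    #      [0.,   0., -inf, -inf],
    #      [0.,   0.,   0., -inf],
    #      [0.,   0.,   0.,   0.]]
    # i.e. -inf strictly above the diagonal, so position i can only attend to positions <= i once the
    # mask has been added to the attention logits.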
    causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
        dtype=causal_mask_dtype, device=decoder_input_ids.device
    )
    return decoder_input_ids, decoder_padding_mask, causal_mask


class PretrainedBartModel(PreTrainedModel):
    config_class = BartConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, SinusoidalPositionalEmbedding):
            pass
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs


def _make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
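    # Editorial note: although the Linear above is declared as (vocab_size, emb_size), overwriting its
    # .weight with the embedding matrix (shape vocab_size x emb_size) makes it map hidden states of size
    # emb_size to vocab_size logits, i.e. the output projection is tied to the input embedding.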
    return lin_layer


# Helper Functions, mostly for making masks
def _check_shapes(shape_1, shape2):
    if shape_1 != shape2:
        raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2))


def shift_tokens_right(input_ids, pad_token_id):
    """Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
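    # Editorial example (values are illustrative; assumes pad_token_id=1 and eos token id 2):
    # an input of [[0, 31414, 232, 2]] becomes [[2, 0, 31414, 232]], i.e. the eos token is
    # wrapped around to position 0 and everything else is shifted right by one.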
    return prev_output_tokens


def make_padding_mask(input_ids, padding_idx=1):
    """True for pad tokens"""
    padding_mask = input_ids.eq(padding_idx)
    if not padding_mask.any():
        padding_mask = None
    return padding_mask


# Helper Modules


class EncoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask, output_attentions=False):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
                Positions marked ``1`` are excluded from (masked out of) attention;
                positions marked ``0`` are included.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, attn_weights = self.self_attn(
            query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, attn_weights


class BartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
    is a :class:`EncoderLayer`.

    Args:
        config: BartConfig
    """

    def __init__(self, config: BartConfig, embed_tokens):
        super().__init__()

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, embed_dim, self.padding_idx
            )
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings,
                embed_dim,
                self.padding_idx,
                config.extra_pos_embeddings,
            )
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
        # mbart has one extra layer_norm
        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None

    def forward(
        self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
    ):
        """
        Args:
            input_ids (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            attention_mask (torch.LongTensor): indicating which indices are padding tokens.
        Returns:
            BaseModelOutput or Tuple comprised of:
                - **x** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_states** (tuple(torch.FloatTensor)): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *output_hidden_states* is True.
                - **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
                During training this may contain fewer than n_layers entries because of layer dropout.
        """
        # check attention mask and invert
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        embed_pos = self.embed_positions(input_ids)
        x = inputs_embeds + embed_pos
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = [] if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)

            if output_attentions:
                all_attentions = all_attentions + (attn,)

        if self.layer_norm:
            x = self.layer_norm(x)
        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if not return_dict:
            return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)


class DecoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = Attention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.normalize_before = config.normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.encoder_attn = Attention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attn_mask=None,
        layer_state=None,
        causal_mask=None,
        decoder_padding_mask=None,
        output_attentions=False,
    ):
        residual = x

        if layer_state is None:
            layer_state = {}
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # Self Attention

        x, self_attn_weights = self.self_attn(
            query=x,
            key=x,
            layer_state=layer_state,  # adds keys to layer state
            key_padding_mask=decoder_padding_mask,
            attn_mask=causal_mask,
            output_attentions=output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Cross attention
        residual = x
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
        if self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
        x, _ = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.encoder_attn_layer_norm(x)

        # Fully Connected
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return (
            x,
            self_attn_weights,
            layer_state,
        )  # just self_attn weights for now, following t5, layer_state = cache for decoding


class BartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer
    is a :class:`DecoderLayer`.
    Args:
        config: BartConfig
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
        super().__init__()
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, config.d_model, config.pad_token_id
            )
        else:
            self.embed_positions = LearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
                config.extra_pos_embeddings,
            )
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.decoder_layers)]
        )  # type: List[DecoderLayer]
        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        decoder_causal_mask,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
        **unused,
    ):
        """
        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            input_ids (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention
            encoder_padding_mask: for ignoring pad tokens
            past_key_values (dict or None): dictionary used for storing state during generation

        Returns:
            BaseModelOutputWithPast or tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - the cache
                - hidden states
                - attentions
        """
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_cached_states")
        if "decoder_past_key_values" in unused:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_past_key_values")

        # check attention mask and invert
        if encoder_padding_mask is not None:
            encoder_padding_mask = invert_mask(encoder_padding_mask)

        # embed positions
        positions = self.embed_positions(input_ids, use_cache=use_cache)

        if use_cache:
            input_ids = input_ids[:, -1:]
            positions = positions[:, -1:]  # happens after we embed them
            # assert input_ids.ne(self.padding_idx).any()

        x = self.embed_tokens(input_ids) * self.embed_scale
        x += positions
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Convert to the time-first layout used by the layers: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = []
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (x,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            layer_state = past_key_values[idx] if past_key_values is not None else None

            x, layer_self_attn, layer_past = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
                decoder_padding_mask=decoder_padding_mask,
                layer_state=layer_state,
                causal_mask=decoder_causal_mask,
                output_attentions=output_attentions,
            )

            if use_cache:
                next_decoder_cache.append(layer_past.copy())

            if self.layer_norm and (idx == len(self.layers) - 1):  # if config.add_final_layer_norm (mBART)
                x = self.layer_norm(x)
            if output_attentions:
                all_self_attns += (layer_self_attn,)

        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        if output_hidden_states:
            all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
        )


def _reorder_buffer(attn_cache, new_order):
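    # Editorial note: `attn_cache` is the per-attention-type slice of a decoder layer's cache (keyed by
    # Attention.cache_key, "self" or "encoder_decoder"); its "prev_key"/"prev_value" entries have shape
    # (bsz, num_heads, seq_len, head_dim), so index_select along dim 0 reorders the batch dimension,
    # which is what generation needs when beam hypotheses are reordered.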
    for k, input_buffer_k in attn_cache.items():
        if input_buffer_k is not None:
            attn_cache[k] = input_buffer_k.index_select(0, new_order)
    return attn_cache


class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        encoder_decoder_attention=False,  # otherwise self_attention
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.encoder_decoder_attention = encoder_decoder_attention
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor, seq_len, bsz):
        return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
        attn_mask: Optional[Tensor] = None,
        output_attentions=False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time(SeqLen) x Batch x Channel"""
        static_kv: bool = self.encoder_decoder_attention
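        # Shape summary (editorial, mirrors the asserts below): `query` is (tgt_len, bsz, embed_dim);
        # q/k/v are reshaped to (bsz * num_heads, seq_len, head_dim); attn_weights is
        # (bsz * num_heads, tgt_len, src_len); the output is projected back to (tgt_len, bsz, embed_dim).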
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        # we get here for encoder-decoder attention because of static_kv
        if layer_state is not None:  # reuse k,v and encoder_padding_mask
            saved_state = layer_state.get(self.cache_key, {})
            if "prev_key" in saved_state and static_kv:
                # previous time steps are cached - no need to recompute key and value if they are static
                key = None
        else:
            saved_state = None
            layer_state = {}

        q = self.q_proj(query) * self.scaling
        if static_kv:
            if key is None:
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self._shape(q, tgt_len, bsz)
        if k is not None:
            k = self._shape(k, -1, bsz)
        if v is not None:
            v = self._shape(v, -1, bsz)

        if saved_state is not None:
            k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)

        # Update cache
        layer_state[self.cache_key] = {
            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
        }

        assert k is not None
        src_len = k.size(1)
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        assert key_padding_mask is None or key_padding_mask.size()[:2] == (
            bsz,
            src_len,
        )

        if key_padding_mask is not None:  # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_probs = F.dropout(
            attn_weights,
            p=self.dropout,
            training=self.training,
        )

        assert v is not None
        attn_output = torch.bmm(attn_probs, v)
        assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        if output_attentions:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        else:
            attn_weights = None
        return attn_output, attn_weights

    def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        assert k is not None and v is not None
        prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
        if prev_key_padding_mask is not None:
            if static_kv:
                new_key_padding_mask = prev_key_padding_mask
            else:
                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
        else:
            new_key_padding_mask = key_padding_mask
        return k, v, new_key_padding_mask


class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    # This can trivially be shared with RobertaClassificationHead

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack.
        self.offset = offset
        assert padding_idx is not None
        num_embeddings += offset
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

    def forward(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
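        # Editorial note: for BART `offset` is typically 2 (config.extra_pos_embeddings), so a sequence
        # of length 4 looks up embedding rows [2, 3, 4, 5]; the first `offset` rows of the table are
        # never reached by `positions + self.offset`.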
        return super().forward(positions + self.offset)


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)


# Public API
def _get_shape(t):
    return getattr(t, "shape", None)


@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
    def __init__(self, config: BartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        self.init_weights()

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        encoder_outputs: Optional[Tuple] = None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")

        if decoder_input_ids is None:
            use_cache = False

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # make masks if user doesn't supply
        if not use_cache:
            decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                self.config,
                input_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_padding_mask=decoder_attention_mask,
                causal_mask_dtype=self.shared.weight.dtype,
            )
        else:
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids,
            encoder_outputs[0],
            attention_mask,
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)  # make it on the fly


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class BartForConditionalGeneration(PretrainedBartModel):
    base_model_prefix = "model"
    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

    def __init__(self, config: BartConfig):
        super().__init__(config)
        base_model = BartModel(config)
        self.model = base_model
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
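        # Editorial note: this buffer is a per-token bias that forward() adds to the LM logits via
        # F.linear(..., bias=self.final_logits_bias); it is initialized to zeros and only changes shape
        # (in _resize_final_logits_bias below) when the vocabulary is resized.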

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old_num_tokens = self.model.shared.num_embeddings
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self.model.shared = new_embeddings
        self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None:
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **unused,
    ):
        r"""
            labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
                Labels for computing the masked language modeling loss.
                Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
                Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
                with labels in ``[0, ..., config.vocab_size]``.

        Returns:

        Conditional generation example::

                # Mask filling only works for bart-large
                from transformers import BartTokenizer, BartForConditionalGeneration
                tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
                TXT = "My friends are <mask> but they eat too many carbs."

                model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
                input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
                logits = model(input_ids).logits

                masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
                probs = logits[0, masked_index].softmax(dim=0)
                values, predictions = probs.topk(5)

                tokenizer.decode(predictions).split()
                # ['good', 'great', 'all', 'really', 'very']
        """
        if "lm_labels" in unused:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = unused.pop("lm_labels")
        if "decoder_cached_states" in unused:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_cached_states")
        if "decoder_past_key_values" in unused:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = unused.pop("decoder_past_key_values")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            use_cache = False
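            # teacher forcing: if no decoder inputs are given, build them from the labels
            # shifted one position to the right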
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
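        # project decoder hidden states onto the shared (tied) embedding matrix and add the
        # resizable final_logits_bias buffer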
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # TODO(SS): do we need to ignore pad tokens in labels?
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
    ):
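        # the encoder only needs to run once per generation; afterwards only the decoder inputs
        # and the cached key/value states change from step to step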
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
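        # optionally force <bos> at the first decoding step, and force <eos> at the last
        # allowed step so every hypothesis ends with a proper end-of-sequence token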
        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_id) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
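        # Illustrative sketch (not part of the original code): with a vocabulary of 5 and
        # token_id = 2, every other entry becomes -inf, so softmax puts all mass on token 2:
        #   scores = torch.zeros(1, 5); scores[:, [0, 1, 3, 4]] = -float("inf")
        #   scores.softmax(dim=-1)  # -> tensor([[0., 0., 1., 0., 0.]])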

    @staticmethod
    def _reorder_cache(past, beam_idx):
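        # `past` holds one dict of attention caches per decoder layer (self-attention and
        # encoder-decoder attention); re-index their batch dimension so each beam keeps the
        # history of the hypothesis it was reordered to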
        reordered_past = []
        for layer_past in past:
            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
            layer_past_new = {
                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
            }
            reordered_past.append(layer_past_new)
        return reordered_past

    def get_encoder(self):
        return self.model.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly


@add_start_docstrings(
    """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
    BART_START_DOCSTRING,
)
class BartForSequenceClassification(PretrainedBartModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classif_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        x = outputs[0]  # last hidden state
        eos_mask = input_ids.eq(self.config.eos_token_id)
        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
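        # pool by taking the decoder hidden state at the final <eos> token of each sequence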
        sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


@add_start_docstrings(
    """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BART_START_DOCSTRING,
)
class BartForQuestionAnswering(PretrainedBartModel):
    def __init__(self, config):
        super().__init__(config)

        config.num_labels = 2
        self.num_labels = config.num_labels

        self.model = BartModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.model._init_weights(self.qa_outputs)

    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=Seq2SeqQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        start_positions=None,
        end_positions=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if start_positions is not None and end_positions is not None:
            use_cache = False

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
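        # each token gets two scores; split them into separate start and end logits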
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


class SinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        super().__init__(num_positions, embedding_dim)
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
        The cos features are in the 2nd half of the vector. [dim // 2:]
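
        For example, with ``dim = 4`` the columns are
        ``[sin(pos / 10000^0), sin(pos / 10000^(1/2)), cos(pos / 10000^0), cos(pos / 10000^(1/2))]``,
        i.e. all sine features first, followed by the matching cosine features.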
        """
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # this line breaks for odd dim (rejected in __init__)
        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        out.requires_grad = False
        return out

    @torch.no_grad()
    def forward(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
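            # with incremental decoding only the newest token needs an embedding, so a single
            # position (the current length - 1) is emitted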
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing
        else:
            # starts at 0, ends at seq_len - 1
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)